feat: Enhance plugin loading and management
- Added module_name parameter to PluginMeta for better module tracking. - Improved documentation for PluginMeta and PluginLoader methods. - Introduced methods for managing loaded plugins: set_loaded_plugin, remove_loaded_plugin, and purge_plugin_modules. - Enhanced dependency resolution in PluginLoader with resolve_dependencies method. - Implemented candidate discovery and loading in PluginLoader. - Added support for plugin reloading with _reload_plugin_by_id in PluginRunner. - Improved error handling and logging throughout the RPCClient and PluginRunner. - Added support for handling hook invocations in PluginRunner. - Refactored plugin registration and unregistration processes for clarity and efficiency.
This commit is contained in:
@@ -174,7 +174,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
|
||||
if registered_supervisor is not None:
|
||||
try:
|
||||
reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}")
|
||||
reloaded = await registered_supervisor.reload_plugins(
|
||||
plugin_ids=[plugin_name],
|
||||
reason=f"load {plugin_name}",
|
||||
)
|
||||
if reloaded:
|
||||
return {"success": True, "count": 1}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
@@ -186,7 +189,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
for pdir in sv._plugin_dirs:
|
||||
if (pdir / plugin_name).is_dir():
|
||||
try:
|
||||
reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
|
||||
reloaded = await sv.reload_plugins(
|
||||
plugin_ids=[plugin_name],
|
||||
reason=f"load {plugin_name}",
|
||||
)
|
||||
if reloaded:
|
||||
return {"success": True, "count": 1}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
@@ -222,7 +228,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
|
||||
if sv is not None:
|
||||
try:
|
||||
reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
|
||||
reloaded = await sv.reload_plugins(
|
||||
plugin_ids=[plugin_name],
|
||||
reason=f"reload {plugin_name}",
|
||||
)
|
||||
if reloaded:
|
||||
return {"success": True}
|
||||
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
@@ -12,67 +12,78 @@ logger = get_logger("plugin_runtime.integration")
|
||||
def register_capability_impls(manager: "PluginRuntimeManager", supervisor: PluginSupervisor) -> None:
|
||||
"""向指定 Supervisor 注册主程序提供的能力实现。"""
|
||||
cap_service = supervisor.capability_service
|
||||
rpc_server = supervisor.rpc_server
|
||||
|
||||
cap_service.register_capability("send.text", manager._cap_send_text)
|
||||
cap_service.register_capability("send.emoji", manager._cap_send_emoji)
|
||||
cap_service.register_capability("send.image", manager._cap_send_image)
|
||||
cap_service.register_capability("send.command", manager._cap_send_command)
|
||||
cap_service.register_capability("send.custom", manager._cap_send_custom)
|
||||
def _register(name: str, impl: Any) -> None:
|
||||
"""注册单个能力实现及其 RPC 入口。
|
||||
|
||||
cap_service.register_capability("llm.generate", manager._cap_llm_generate)
|
||||
cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
|
||||
cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models)
|
||||
Args:
|
||||
name: 能力名称。
|
||||
impl: 能力实现函数。
|
||||
"""
|
||||
cap_service.register_capability(name, impl)
|
||||
rpc_server.register_method(name, cap_service.handle_capability_request)
|
||||
|
||||
cap_service.register_capability("config.get", manager._cap_config_get)
|
||||
cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin)
|
||||
cap_service.register_capability("config.get_all", manager._cap_config_get_all)
|
||||
_register("send.text", manager._cap_send_text)
|
||||
_register("send.emoji", manager._cap_send_emoji)
|
||||
_register("send.image", manager._cap_send_image)
|
||||
_register("send.command", manager._cap_send_command)
|
||||
_register("send.custom", manager._cap_send_custom)
|
||||
|
||||
cap_service.register_capability("database.query", manager._cap_database_query)
|
||||
cap_service.register_capability("database.save", manager._cap_database_save)
|
||||
cap_service.register_capability("database.get", manager._cap_database_get)
|
||||
cap_service.register_capability("database.delete", manager._cap_database_delete)
|
||||
cap_service.register_capability("database.count", manager._cap_database_count)
|
||||
_register("llm.generate", manager._cap_llm_generate)
|
||||
_register("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
|
||||
_register("llm.get_available_models", manager._cap_llm_get_available_models)
|
||||
|
||||
cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams)
|
||||
cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams)
|
||||
cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams)
|
||||
cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
|
||||
cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
|
||||
_register("config.get", manager._cap_config_get)
|
||||
_register("config.get_plugin", manager._cap_config_get_plugin)
|
||||
_register("config.get_all", manager._cap_config_get_all)
|
||||
|
||||
cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time)
|
||||
cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
|
||||
cap_service.register_capability("message.get_recent", manager._cap_message_get_recent)
|
||||
cap_service.register_capability("message.count_new", manager._cap_message_count_new)
|
||||
cap_service.register_capability("message.build_readable", manager._cap_message_build_readable)
|
||||
_register("database.query", manager._cap_database_query)
|
||||
_register("database.save", manager._cap_database_save)
|
||||
_register("database.get", manager._cap_database_get)
|
||||
_register("database.delete", manager._cap_database_delete)
|
||||
_register("database.count", manager._cap_database_count)
|
||||
|
||||
cap_service.register_capability("person.get_id", manager._cap_person_get_id)
|
||||
cap_service.register_capability("person.get_value", manager._cap_person_get_value)
|
||||
cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name)
|
||||
_register("chat.get_all_streams", manager._cap_chat_get_all_streams)
|
||||
_register("chat.get_group_streams", manager._cap_chat_get_group_streams)
|
||||
_register("chat.get_private_streams", manager._cap_chat_get_private_streams)
|
||||
_register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
|
||||
_register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
|
||||
|
||||
cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description)
|
||||
cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random)
|
||||
cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count)
|
||||
cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions)
|
||||
cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all)
|
||||
cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info)
|
||||
cap_service.register_capability("emoji.register", manager._cap_emoji_register)
|
||||
cap_service.register_capability("emoji.delete", manager._cap_emoji_delete)
|
||||
_register("message.get_by_time", manager._cap_message_get_by_time)
|
||||
_register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
|
||||
_register("message.get_recent", manager._cap_message_get_recent)
|
||||
_register("message.count_new", manager._cap_message_count_new)
|
||||
_register("message.build_readable", manager._cap_message_build_readable)
|
||||
|
||||
cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
|
||||
cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust)
|
||||
cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust)
|
||||
_register("person.get_id", manager._cap_person_get_id)
|
||||
_register("person.get_value", manager._cap_person_get_value)
|
||||
_register("person.get_id_by_name", manager._cap_person_get_id_by_name)
|
||||
|
||||
cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions)
|
||||
_register("emoji.get_by_description", manager._cap_emoji_get_by_description)
|
||||
_register("emoji.get_random", manager._cap_emoji_get_random)
|
||||
_register("emoji.get_count", manager._cap_emoji_get_count)
|
||||
_register("emoji.get_emotions", manager._cap_emoji_get_emotions)
|
||||
_register("emoji.get_all", manager._cap_emoji_get_all)
|
||||
_register("emoji.get_info", manager._cap_emoji_get_info)
|
||||
_register("emoji.register", manager._cap_emoji_register)
|
||||
_register("emoji.delete", manager._cap_emoji_delete)
|
||||
|
||||
cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins)
|
||||
cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info)
|
||||
cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
|
||||
cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
|
||||
cap_service.register_capability("component.enable", manager._cap_component_enable)
|
||||
cap_service.register_capability("component.disable", manager._cap_component_disable)
|
||||
cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin)
|
||||
cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin)
|
||||
cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin)
|
||||
_register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
|
||||
_register("frequency.set_adjust", manager._cap_frequency_set_adjust)
|
||||
_register("frequency.get_adjust", manager._cap_frequency_get_adjust)
|
||||
|
||||
cap_service.register_capability("knowledge.search", manager._cap_knowledge_search)
|
||||
_register("tool.get_definitions", manager._cap_tool_get_definitions)
|
||||
|
||||
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
|
||||
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
|
||||
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
|
||||
_register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
|
||||
_register("component.enable", manager._cap_component_enable)
|
||||
_register("component.disable", manager._cap_component_disable)
|
||||
_register("component.load_plugin", manager._cap_component_load_plugin)
|
||||
_register("component.unload_plugin", manager._cap_component_unload_plugin)
|
||||
_register("component.reload_plugin", manager._cap_component_reload_plugin)
|
||||
|
||||
_register("knowledge.search", manager._cap_knowledge_search)
|
||||
logger.debug("已注册全部主程序能力实现")
|
||||
|
||||
@@ -30,6 +30,11 @@ class CapabilityService:
|
||||
"""
|
||||
|
||||
def __init__(self, authorization: "AuthorizationManager") -> None:
|
||||
"""初始化能力服务。
|
||||
|
||||
Args:
|
||||
authorization: 能力授权管理器。
|
||||
"""
|
||||
self._authorization = authorization
|
||||
# capability_name -> implementation
|
||||
self._implementations: Dict[str, CapabilityImpl] = {}
|
||||
@@ -51,13 +56,19 @@ class CapabilityService:
|
||||
校验权限后调用对应实现。
|
||||
"""
|
||||
plugin_id = envelope.plugin_id
|
||||
payload = envelope.payload if isinstance(envelope.payload, dict) else {}
|
||||
|
||||
try:
|
||||
req = CapabilityRequestPayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}")
|
||||
req = CapabilityRequestPayload.model_validate(payload)
|
||||
capability = req.capability
|
||||
args = req.args
|
||||
except Exception:
|
||||
capability = envelope.method
|
||||
raw_args = payload.get("args", payload)
|
||||
args = raw_args if isinstance(raw_args, dict) else {}
|
||||
|
||||
capability = req.capability
|
||||
if not capability:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, "能力调用缺少 capability")
|
||||
|
||||
# 1. 权限校验
|
||||
allowed, reason = self._authorization.check_capability(plugin_id, capability)
|
||||
@@ -71,7 +82,7 @@ class CapabilityService:
|
||||
|
||||
# 3. 执行
|
||||
try:
|
||||
result = await impl(plugin_id, capability, req.args)
|
||||
result = await impl(plugin_id, capability, args)
|
||||
resp_payload = CapabilityResponsePayload(success=True, result=result)
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
except RPCError as e:
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_runtime.transport.factory import create_transport_server
|
||||
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
BootstrapPluginPayload,
|
||||
ConfigUpdatedPayload,
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
LogBatchPayload,
|
||||
PROTOCOL_VERSION,
|
||||
RegisterPluginPayload,
|
||||
ReloadPluginResultPayload,
|
||||
RunnerReadyPayload,
|
||||
ShutdownPayload,
|
||||
UnregisterPluginPayload,
|
||||
)
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
from src.plugin_runtime.transport.factory import create_transport_server
|
||||
|
||||
from .authorization import AuthorizationManager
|
||||
from .capability_service import CapabilityService
|
||||
from .rpc_server import RPCServer
|
||||
from .logger_bridge import RunnerLogBridge
|
||||
from .component_registry import ComponentRegistry
|
||||
from .event_dispatcher import EventDispatcher
|
||||
from .hook_dispatcher import HookDispatcher
|
||||
from .logger_bridge import RunnerLogBridge
|
||||
from .message_gateway import MessageGateway
|
||||
from .message_utils import PluginMessageUtils
|
||||
from .rpc_server import RPCServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
@@ -35,7 +41,11 @@ logger = get_logger("plugin_runtime.host.runner_manager")
|
||||
|
||||
|
||||
class PluginRunnerSupervisor:
|
||||
"""插件的Runner管理器,负责管理Runner的生命周期"""
|
||||
"""插件 Runner 监督器。
|
||||
|
||||
负责 Host 侧与单个 Runner 子进程之间的生命周期、内部 RPC、
|
||||
健康检查和插件级重载协调。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,13 +54,24 @@ class PluginRunnerSupervisor:
|
||||
health_check_interval_sec: Optional[float] = None,
|
||||
max_restart_attempts: Optional[int] = None,
|
||||
runner_spawn_timeout_sec: Optional[float] = None,
|
||||
):
|
||||
_cfg = global_config.plugin_runtime
|
||||
self._plugin_dirs: List[Path] = plugin_dirs or []
|
||||
self._health_interval = health_check_interval_sec or _cfg.health_check_interval_sec or 30.0
|
||||
self._runner_spawn_timeout = runner_spawn_timeout_sec or _cfg.runner_spawn_timeout_sec or 30.0
|
||||
) -> None:
|
||||
"""初始化 Supervisor。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
|
||||
socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
|
||||
health_check_interval_sec: 健康检查间隔,单位秒。
|
||||
max_restart_attempts: 自动重启 Runner 的最大次数。
|
||||
runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。
|
||||
"""
|
||||
runtime_config = global_config.plugin_runtime
|
||||
self._plugin_dirs: List[Path] = plugin_dirs or []
|
||||
self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
|
||||
self._runner_spawn_timeout: float = (
|
||||
runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
|
||||
)
|
||||
self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3
|
||||
|
||||
# 基础设施
|
||||
self._transport = create_transport_server(socket_path=socket_path)
|
||||
self._authorization = AuthorizationManager()
|
||||
self._capability_service = CapabilityService(self._authorization)
|
||||
@@ -58,61 +79,55 @@ class PluginRunnerSupervisor:
|
||||
self._event_dispatcher = EventDispatcher(self._component_registry)
|
||||
self._hook_dispatcher = HookDispatcher(self._component_registry)
|
||||
self._message_gateway = MessageGateway(self._component_registry)
|
||||
|
||||
# 编解码和服务器
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
self._log_bridge = RunnerLogBridge()
|
||||
|
||||
codec = MsgPackCodec()
|
||||
self._rpc_server = RPCServer(transport=self._transport, codec=codec)
|
||||
|
||||
# Runner 子进程
|
||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||
self._max_restart_attempts: int = max_restart_attempts or _cfg.max_restart_attempts or 3
|
||||
self._restart_count: int = 0
|
||||
|
||||
# 已注册的插件组件信息
|
||||
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
||||
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
||||
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
||||
self._health_task: Optional[asyncio.Task[None]] = None
|
||||
self._stderr_drain_task: Optional[asyncio.Task[None]] = None
|
||||
self._restart_count: int = 0
|
||||
self._running: bool = False
|
||||
|
||||
# 后台任务
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
# Runner stderr 流排空任务(仅保留 stderr,用于 IPC 建立前的启动日志倒空、致命错误输出等场景)
|
||||
self._stderr_drain_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
# Runner 日志桥(将 Runner 上报的批量日志重放到主进程 Logger)
|
||||
self._log_bridge: RunnerLogBridge = RunnerLogBridge()
|
||||
|
||||
# 注册内部 RPC 方法
|
||||
self._register_internal_methods() # TODO: 完成内部方法注册
|
||||
self._register_internal_methods()
|
||||
|
||||
@property
|
||||
def authorization_manager(self) -> AuthorizationManager:
|
||||
"""返回授权管理器。"""
|
||||
return self._authorization
|
||||
|
||||
@property
|
||||
def capability_service(self) -> CapabilityService:
|
||||
"""返回能力服务。"""
|
||||
return self._capability_service
|
||||
|
||||
@property
|
||||
def component_registry(self) -> ComponentRegistry:
|
||||
"""返回组件注册表。"""
|
||||
return self._component_registry
|
||||
|
||||
@property
|
||||
def event_dispatcher(self) -> EventDispatcher:
|
||||
"""返回事件分发器。"""
|
||||
return self._event_dispatcher
|
||||
|
||||
@property
|
||||
def hook_dispatcher(self) -> HookDispatcher:
|
||||
"""返回 Hook 分发器。"""
|
||||
return self._hook_dispatcher
|
||||
|
||||
@property
|
||||
def message_gateway(self) -> MessageGateway:
|
||||
"""返回消息网关。"""
|
||||
return self._message_gateway
|
||||
|
||||
@property
|
||||
def rpc_server(self) -> RPCServer:
|
||||
"""返回底层 RPC 服务端。"""
|
||||
return self._rpc_server
|
||||
|
||||
async def dispatch_event(
|
||||
@@ -121,11 +136,28 @@ class PluginRunnerSupervisor:
|
||||
message: Optional["SessionMessage"] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional["SessionMessage"]]:
|
||||
"""分发事件到所有对应 handler 的快捷方法。"""
|
||||
"""分发事件到已注册的事件处理器。
|
||||
|
||||
Args:
|
||||
event_type: 事件类型。
|
||||
message: 可选的消息对象。
|
||||
extra_args: 附加参数。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[SessionMessage]]: 是否继续处理,以及插件可能修改后的消息。
|
||||
"""
|
||||
return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args)
|
||||
|
||||
async def dispatch_hook(self, stage: str, **kwargs):
|
||||
"""分发Hook事件到所有对应 handler 的快捷方法。"""
|
||||
async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""分发 Hook 到已注册的 Hook 处理器。
|
||||
|
||||
Args:
|
||||
stage: Hook 阶段名称。
|
||||
**kwargs: 传递给 Hook 的关键字参数。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 经 Hook 修改后的参数字典。
|
||||
"""
|
||||
return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs)
|
||||
|
||||
async def send_message_to_external(
|
||||
@@ -135,60 +167,68 @@ class PluginRunnerSupervisor:
|
||||
enabled_only: bool = True,
|
||||
save_to_db: bool = True,
|
||||
) -> bool:
|
||||
"""发送系统内部消息到外部平台的快捷方法。"""
|
||||
"""通过插件消息网关发送外部消息。
|
||||
|
||||
Args:
|
||||
internal_message: 系统内部消息对象。
|
||||
enabled_only: 是否仅使用启用的网关组件。
|
||||
save_to_db: 发送成功后是否写入数据库。
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功。
|
||||
"""
|
||||
return await self._message_gateway.send_message_to_external(
|
||||
internal_message, self, enabled_only=enabled_only, save_to_db=save_to_db
|
||||
internal_message,
|
||||
self,
|
||||
enabled_only=enabled_only,
|
||||
save_to_db=save_to_db,
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 Supervisor
|
||||
"""启动 Supervisor。"""
|
||||
if self._running:
|
||||
logger.warning("PluginRunnerSupervisor 已在运行,跳过重复启动")
|
||||
return
|
||||
|
||||
1. 启动 RPC Server
|
||||
2. 拉起 Runner 子进程
|
||||
3. 启动健康检查
|
||||
"""
|
||||
self._running = True
|
||||
self._restart_count = 0
|
||||
self._clear_runner_state()
|
||||
|
||||
# 启动 RPC Server
|
||||
await self._rpc_server.start()
|
||||
# 拉起 Runner 进程
|
||||
await self._spawn_runner()
|
||||
|
||||
# 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪
|
||||
try:
|
||||
await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
|
||||
await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
|
||||
except TimeoutError:
|
||||
if not self._rpc_server.is_connected:
|
||||
logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败")
|
||||
logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败")
|
||||
else:
|
||||
logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成初始化,后续操作可能失败")
|
||||
logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败")
|
||||
|
||||
# 启动健康检查
|
||||
self._health_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
logger.info("PluginSupervisor 已启动")
|
||||
self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health")
|
||||
logger.info("PluginRunnerSupervisor 已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 Supervisor"""
|
||||
"""停止 Supervisor。"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
# 停止组件
|
||||
await self._event_dispatcher.stop()
|
||||
await self._hook_dispatcher.stop()
|
||||
|
||||
# 停止健康检查
|
||||
if self._health_task:
|
||||
if self._health_task is not None:
|
||||
self._health_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._health_task
|
||||
self._health_task = None
|
||||
|
||||
# 优雅关停 Runner
|
||||
await self._shutdown_runner()
|
||||
|
||||
# 停止 RPC Server
|
||||
await self._event_dispatcher.stop()
|
||||
await self._hook_dispatcher.stop()
|
||||
await self._shutdown_runner(reason="host_stop")
|
||||
await self._rpc_server.stop()
|
||||
self._clear_runner_state()
|
||||
|
||||
logger.info("PluginSupervisor 已停止")
|
||||
logger.info("PluginRunnerSupervisor 已停止")
|
||||
|
||||
async def invoke_plugin(
|
||||
self,
|
||||
@@ -198,9 +238,17 @@ class PluginRunnerSupervisor:
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""调用插件组件
|
||||
"""调用 Runner 内的插件组件。
|
||||
|
||||
由主进程业务逻辑调用,通过 RPC 转发给 Runner。
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
component_name: 组件名。
|
||||
args: 调用参数。
|
||||
timeout_ms: RPC 超时时间,单位毫秒。
|
||||
|
||||
Returns:
|
||||
Envelope: RPC 响应信封。
|
||||
"""
|
||||
return await self._rpc_server.send_request(
|
||||
method,
|
||||
@@ -210,27 +258,421 @@ class PluginRunnerSupervisor:
|
||||
)
|
||||
|
||||
async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
|
||||
raise NotImplementedError("等待SDK完成") # TODO: 完成对应的调用和请求逻辑
|
||||
"""按插件 ID 触发精确重载。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
reason: 重载原因。
|
||||
|
||||
Returns:
|
||||
bool: 是否重载成功。
|
||||
"""
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
"plugin.reload",
|
||||
plugin_id=plugin_id,
|
||||
payload={"plugin_id": plugin_id, "reason": reason},
|
||||
timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {plugin_id} 重载请求失败: {exc}")
|
||||
return False
|
||||
|
||||
result = ReloadPluginResultPayload.model_validate(response.payload)
|
||||
if not result.success:
|
||||
logger.warning(f"插件 {plugin_id} 重载失败: {result.failed_plugins}")
|
||||
return result.success
|
||||
|
||||
async def reload_plugins(
|
||||
self,
|
||||
plugin_ids: Optional[List[str]] = None,
|
||||
reason: str = "manual",
|
||||
) -> bool:
|
||||
"""批量重载插件。
|
||||
|
||||
Args:
|
||||
plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。
|
||||
reason: 重载原因。
|
||||
|
||||
Returns:
|
||||
bool: 是否全部重载成功。
|
||||
"""
|
||||
target_plugin_ids = plugin_ids or list(self._registered_plugins.keys())
|
||||
ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids))
|
||||
success = True
|
||||
|
||||
for plugin_id in ordered_plugin_ids:
|
||||
reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason)
|
||||
success = success and reloaded
|
||||
|
||||
return success
|
||||
|
||||
async def notify_plugin_config_updated(
|
||||
self,
|
||||
plugin_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
config_version: str = "",
|
||||
) -> bool:
|
||||
"""向 Runner 推送插件配置更新。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
config_data: 配置内容。
|
||||
config_version: 配置版本号。
|
||||
|
||||
Returns:
|
||||
bool: 请求是否成功送达并被 Runner 接受。
|
||||
"""
|
||||
payload = ConfigUpdatedPayload(
|
||||
plugin_id=plugin_id,
|
||||
config_version=config_version,
|
||||
config_data=config_data or {},
|
||||
)
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
"plugin.config_updated",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=10000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 配置更新通知失败: {exc}")
|
||||
return False
|
||||
|
||||
return bool(response.payload.get("acknowledged", False))
|
||||
|
||||
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
|
||||
"""等待 Runner 连接上 RPC Server"""
|
||||
"""等待 Runner 建立 RPC 连接。
|
||||
|
||||
async def wait_for_connection():
|
||||
Args:
|
||||
timeout_sec: 超时时间,单位秒。
|
||||
|
||||
Raises:
|
||||
TimeoutError: 在超时时间内 Runner 未完成连接。
|
||||
"""
|
||||
|
||||
async def wait_for_connection() -> None:
|
||||
"""轮询等待 RPC 连接建立。"""
|
||||
while self._running and not self._rpc_server.is_connected:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(wait_for_connection(), timeout=timeout_sec)
|
||||
logger.info("Runner 已连接到 RPC Server")
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from e
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from exc
|
||||
|
||||
async def _wait_for_runner_ready(self, timeout_sec: float = 30.0) -> RunnerReadyPayload:
|
||||
"""等待 Runner 完成初始化并上报就绪"""
|
||||
"""等待 Runner 完成启动初始化。
|
||||
|
||||
Args:
|
||||
timeout_sec: 超时时间,单位秒。
|
||||
|
||||
Returns:
|
||||
RunnerReadyPayload: Runner 上报的就绪信息。
|
||||
|
||||
Raises:
|
||||
TimeoutError: 在超时时间内 Runner 未完成初始化。
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec)
|
||||
logger.info("Runner 已完成初始化并上报就绪")
|
||||
return self._runner_ready_payloads
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from e
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc
|
||||
|
||||
def _register_internal_methods(self) -> None:
|
||||
"""注册 Host 侧内部 RPC 方法。"""
|
||||
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
|
||||
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
|
||||
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
|
||||
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
|
||||
self._rpc_server.register_method("plugin.unregister", self._handle_unregister_plugin)
|
||||
self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch)
|
||||
self._rpc_server.register_method("runner.ready", self._handle_runner_ready)
|
||||
|
||||
async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope:
|
||||
"""处理插件 bootstrap 请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: RPC 响应信封。
|
||||
"""
|
||||
try:
|
||||
payload = BootstrapPluginPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
if payload.capabilities_required:
|
||||
self._authorization.register_plugin(payload.plugin_id, payload.capabilities_required)
|
||||
else:
|
||||
self._authorization.revoke_permission_token(payload.plugin_id)
|
||||
|
||||
return envelope.make_response(payload={"accepted": True, "plugin_id": payload.plugin_id})
|
||||
|
||||
async def _handle_register_plugin(self, envelope: Envelope) -> Envelope:
|
||||
"""处理插件组件注册请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: RPC 响应信封。
|
||||
"""
|
||||
try:
|
||||
payload = RegisterPluginPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
||||
registered_count = self._component_registry.register_plugin_components(
|
||||
payload.plugin_id,
|
||||
[component.model_dump() for component in payload.components],
|
||||
)
|
||||
self._registered_plugins[payload.plugin_id] = payload
|
||||
|
||||
return envelope.make_response(
|
||||
payload={
|
||||
"accepted": True,
|
||||
"plugin_id": payload.plugin_id,
|
||||
"registered_components": registered_count,
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_unregister_plugin(self, envelope: Envelope) -> Envelope:
|
||||
"""处理插件注销请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: RPC 响应信封。
|
||||
"""
|
||||
try:
|
||||
payload = UnregisterPluginPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
||||
self._authorization.revoke_permission_token(payload.plugin_id)
|
||||
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
|
||||
|
||||
return envelope.make_response(
|
||||
payload={
|
||||
"accepted": True,
|
||||
"plugin_id": payload.plugin_id,
|
||||
"reason": payload.reason,
|
||||
"removed_components": removed_components,
|
||||
"removed_registration": removed_registration,
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_runner_ready(self, envelope: Envelope) -> Envelope:
|
||||
"""处理 Runner 就绪通知。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: RPC 响应信封。
|
||||
"""
|
||||
try:
|
||||
payload = RunnerReadyPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
self._runner_ready_payloads = payload
|
||||
self._runner_ready_events.set()
|
||||
return envelope.make_response(payload={"accepted": True})
|
||||
|
||||
def _build_runner_environment(self) -> Dict[str, str]:
|
||||
"""构建拉起 Runner 所需的环境变量。
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 传递给 Runner 进程的环境变量映射。
|
||||
"""
|
||||
return {
|
||||
ENV_HOST_VERSION: PROTOCOL_VERSION,
|
||||
ENV_IPC_ADDRESS: self._transport.get_address(),
|
||||
ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),
|
||||
ENV_SESSION_TOKEN: self._rpc_server.session_token,
|
||||
}
|
||||
|
||||
async def _spawn_runner(self) -> None:
|
||||
"""拉起 Runner 子进程。"""
|
||||
if self._runner_process is not None and self._runner_process.returncode is None:
|
||||
logger.warning("Runner 已在运行,跳过重复拉起")
|
||||
return
|
||||
|
||||
self._clear_runner_state()
|
||||
|
||||
env = os.environ.copy()
|
||||
env.update(self._build_runner_environment())
|
||||
|
||||
self._runner_process = await asyncio.create_subprocess_exec(
|
||||
sys.executable,
|
||||
"-m",
|
||||
"src.plugin_runtime.runner.runner_main",
|
||||
env=env,
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
if self._runner_process.stderr is not None:
|
||||
self._stderr_drain_task = asyncio.create_task(
|
||||
self._drain_runner_stderr(self._runner_process.stderr),
|
||||
name="PluginRunnerSupervisor.stderr",
|
||||
)
|
||||
|
||||
logger.info(f"Runner 已拉起,pid={self._runner_process.pid}")
|
||||
|
||||
async def _drain_runner_stderr(self, stream: asyncio.StreamReader) -> None:
|
||||
"""持续排空 Runner 的 stderr。
|
||||
|
||||
Args:
|
||||
stream: Runner 的 stderr 流。
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
return
|
||||
message = line.decode("utf-8", errors="replace").rstrip()
|
||||
if message:
|
||||
logger.warning(f"[runner-stderr] {message}")
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning(f"排空 Runner stderr 失败: {exc}")
|
||||
|
||||
async def _shutdown_runner(self, reason: str = "normal") -> None:
|
||||
"""优雅关闭 Runner 子进程。
|
||||
|
||||
Args:
|
||||
reason: 关停原因。
|
||||
"""
|
||||
process = self._runner_process
|
||||
if process is None:
|
||||
return
|
||||
|
||||
payload = ShutdownPayload(reason=reason)
|
||||
|
||||
if process.returncode is None and self._rpc_server.is_connected:
|
||||
with contextlib.suppress(Exception):
|
||||
await self._rpc_server.send_request(
|
||||
"plugin.prepare_shutdown",
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=payload.drain_timeout_ms,
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
await self._rpc_server.send_request(
|
||||
"plugin.shutdown",
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=payload.drain_timeout_ms,
|
||||
)
|
||||
|
||||
if process.returncode is None:
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=max(payload.drain_timeout_ms / 1000.0, 1.0))
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Runner 优雅退出超时,尝试 terminate")
|
||||
process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Runner terminate 超时,尝试 kill")
|
||||
process.kill()
|
||||
with contextlib.suppress(Exception):
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
|
||||
self._runner_process = None
|
||||
|
||||
if self._stderr_drain_task is not None:
|
||||
self._stderr_drain_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._stderr_drain_task
|
||||
self._stderr_drain_task = None
|
||||
|
||||
self._clear_runner_state()
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""周期性检查 Runner 健康状态,并在必要时重启。"""
|
||||
timeout_ms = max(int(self._health_interval * 1000), 1000)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self._health_interval)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
process = self._runner_process
|
||||
if process is None or process.returncode is not None:
|
||||
reason = "runner_process_exited" if process is not None else "runner_process_missing"
|
||||
restarted = await self._restart_runner(reason=reason)
|
||||
if not restarted:
|
||||
return
|
||||
continue
|
||||
|
||||
try:
|
||||
response = await self._rpc_server.send_request("plugin.health", timeout_ms=timeout_ms)
|
||||
health = HealthPayload.model_validate(response.payload)
|
||||
if not health.healthy:
|
||||
restarted = await self._restart_runner(reason="health_check_unhealthy")
|
||||
if not restarted:
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except (RPCError, Exception) as exc:
|
||||
logger.warning(f"Runner 健康检查失败: {exc}")
|
||||
restarted = await self._restart_runner(reason="health_check_failed")
|
||||
if not restarted:
|
||||
return
|
||||
|
||||
async def _restart_runner(self, reason: str) -> bool:
|
||||
"""在 Runner 异常时执行整进程级重启。
|
||||
|
||||
Args:
|
||||
reason: 触发重启的原因。
|
||||
|
||||
Returns:
|
||||
bool: 是否重启成功。
|
||||
"""
|
||||
if not self._running:
|
||||
return False
|
||||
|
||||
if self._restart_count >= self._max_restart_attempts:
|
||||
logger.error(f"Runner 自动重启次数已达上限,停止重启。reason={reason}")
|
||||
return False
|
||||
|
||||
self._restart_count += 1
|
||||
logger.warning(f"准备重启 Runner,第 {self._restart_count} 次,reason={reason}")
|
||||
|
||||
await self._shutdown_runner(reason=reason)
|
||||
|
||||
try:
|
||||
await self._spawn_runner()
|
||||
await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout)
|
||||
await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(f"Runner 重启失败: {exc}", exc_info=True)
|
||||
return False
|
||||
|
||||
self._restart_count = 0
|
||||
logger.info("Runner 已成功重启")
|
||||
return True
|
||||
|
||||
def _clear_runner_state(self) -> None:
|
||||
"""清理当前 Runner 对应的 Host 侧注册状态。"""
|
||||
self._authorization.clear()
|
||||
self._component_registry.clear()
|
||||
self._registered_plugins.clear()
|
||||
self._runner_ready_events = asyncio.Event()
|
||||
self._runner_ready_payloads = RunnerReadyPayload()
|
||||
|
||||
|
||||
PluginSupervisor = PluginRunnerSupervisor
|
||||
|
||||
@@ -23,8 +23,10 @@ from src.plugin_runtime.capabilities import (
|
||||
RuntimeDataCapabilityMixin,
|
||||
)
|
||||
from src.plugin_runtime.capabilities.registry import register_capability_impls
|
||||
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
@@ -223,9 +225,9 @@ class PluginRuntimeManager(
|
||||
async def bridge_event(
|
||||
self,
|
||||
event_type_value: str,
|
||||
message_dict: Optional[Dict[str, Any]] = None,
|
||||
message_dict: Optional[MessageDict] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
) -> Tuple[bool, Optional[MessageDict]]:
|
||||
"""将事件分发到所有 Supervisor
|
||||
|
||||
Returns:
|
||||
@@ -235,17 +237,23 @@ class PluginRuntimeManager(
|
||||
return True, None
|
||||
|
||||
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
|
||||
modified: Optional[Dict[str, Any]] = None
|
||||
modified: Optional[MessageDict] = None
|
||||
current_message: Optional["SessionMessage"] = (
|
||||
PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
|
||||
if message_dict is not None
|
||||
else None
|
||||
)
|
||||
|
||||
for sv in self.supervisors:
|
||||
try:
|
||||
cont, mod = await sv.dispatch_event(
|
||||
event_type=new_event_type,
|
||||
message=modified or message_dict,
|
||||
message=current_message,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
if mod is not None:
|
||||
modified = mod
|
||||
current_message = mod
|
||||
modified = PluginMessageUtils._session_message_to_dict(mod)
|
||||
if not cont:
|
||||
return False, modified
|
||||
except Exception as e:
|
||||
@@ -477,7 +485,7 @@ class PluginRuntimeManager(
|
||||
logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}")
|
||||
return
|
||||
|
||||
reload_supervisors: List[Any] = []
|
||||
reload_supervisors: Dict[Any, List[str]] = {}
|
||||
changed_paths = [change.path.resolve() for change in changes]
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
@@ -485,11 +493,13 @@ class PluginRuntimeManager(
|
||||
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
|
||||
if plugin_id is None:
|
||||
continue
|
||||
if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors:
|
||||
reload_supervisors.append(supervisor)
|
||||
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
|
||||
reload_supervisors.setdefault(supervisor, [])
|
||||
if plugin_id not in reload_supervisors[supervisor]:
|
||||
reload_supervisors[supervisor].append(plugin_id)
|
||||
|
||||
for supervisor in reload_supervisors:
|
||||
await supervisor.reload_plugins(reason="file_watcher")
|
||||
for supervisor, plugin_ids in reload_supervisors.items():
|
||||
await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher")
|
||||
|
||||
if reload_supervisors:
|
||||
self._refresh_plugin_config_watch_subscriptions()
|
||||
|
||||
@@ -144,7 +144,11 @@ class ComponentDeclaration(BaseModel):
|
||||
|
||||
|
||||
class RegisterPluginPayload(BaseModel):
|
||||
"""plugin.register_plugin 请求 payload"""
|
||||
"""插件组件注册请求载荷。
|
||||
|
||||
该模型同时用于 ``plugin.register_components`` 与兼容旧命名的
|
||||
``plugin.register_plugin`` 请求。
|
||||
"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
@@ -248,6 +252,39 @@ class ShutdownPayload(BaseModel):
|
||||
"""排空超时 (ms)"""
|
||||
|
||||
|
||||
class UnregisterPluginPayload(BaseModel):
|
||||
"""插件注销请求载荷。"""
|
||||
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
"""插件 ID"""
|
||||
reason: str = Field(default="manual", description="注销原因")
|
||||
"""注销原因"""
|
||||
|
||||
|
||||
class ReloadPluginPayload(BaseModel):
|
||||
"""插件重载请求载荷。"""
|
||||
|
||||
plugin_id: str = Field(description="目标插件 ID")
|
||||
"""目标插件 ID"""
|
||||
reason: str = Field(default="manual", description="重载原因")
|
||||
"""重载原因"""
|
||||
|
||||
|
||||
class ReloadPluginResultPayload(BaseModel):
|
||||
"""插件重载结果载荷。"""
|
||||
|
||||
success: bool = Field(description="是否重载成功")
|
||||
"""是否重载成功"""
|
||||
requested_plugin_id: str = Field(description="请求重载的插件 ID")
|
||||
"""请求重载的插件 ID"""
|
||||
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
# ====== 日志传输 ======
|
||||
|
||||
|
||||
|
||||
@@ -32,11 +32,22 @@ class PluginMeta:
|
||||
self,
|
||||
plugin_id: str,
|
||||
plugin_dir: str,
|
||||
module_name: str,
|
||||
plugin_instance: Any,
|
||||
manifest: Dict[str, Any],
|
||||
) -> None:
|
||||
"""初始化插件元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
plugin_dir: 插件目录绝对路径。
|
||||
module_name: 插件入口模块名。
|
||||
plugin_instance: 插件实例对象。
|
||||
manifest: 解析后的 manifest 内容。
|
||||
"""
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_dir = plugin_dir
|
||||
self.module_name = module_name
|
||||
self.instance = plugin_instance
|
||||
self.manifest = manifest
|
||||
self.version = manifest.get("version", "1.0.0")
|
||||
@@ -45,6 +56,14 @@ class PluginMeta:
|
||||
|
||||
@staticmethod
|
||||
def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]:
|
||||
"""从 manifest 中提取依赖列表。
|
||||
|
||||
Args:
|
||||
manifest: 插件 manifest。
|
||||
|
||||
Returns:
|
||||
List[str]: 规范化后的依赖插件 ID 列表。
|
||||
"""
|
||||
raw = manifest.get("dependencies", [])
|
||||
result: List[str] = []
|
||||
for dep in raw:
|
||||
@@ -66,19 +85,24 @@ class PluginLoader:
|
||||
"""
|
||||
|
||||
def __init__(self, host_version: str = "") -> None:
|
||||
"""初始化插件加载器。
|
||||
|
||||
Args:
|
||||
host_version: Host 版本号,用于 manifest 兼容性校验。
|
||||
"""
|
||||
self._loaded_plugins: Dict[str, PluginMeta] = {}
|
||||
self._failed_plugins: Dict[str, str] = {}
|
||||
self._manifest_validator = ManifestValidator(host_version=host_version)
|
||||
self._compat_hook_installed = False
|
||||
|
||||
def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
|
||||
"""扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
|
||||
"""扫描多个目录并加载所有插件。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 插件目录列表
|
||||
plugin_dirs: 插件目录列表。
|
||||
|
||||
Returns:
|
||||
成功加载的插件元数据列表(按依赖顺序)
|
||||
List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
|
||||
"""
|
||||
candidates, duplicate_candidates = self._discover_candidates(plugin_dirs)
|
||||
self._record_duplicate_candidates(duplicate_candidates)
|
||||
@@ -90,6 +114,18 @@ class PluginLoader:
|
||||
# 第三阶段:按依赖顺序加载
|
||||
return self._load_plugins_in_order(load_order, candidates)
|
||||
|
||||
def discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
|
||||
"""扫描插件目录并返回候选插件。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 需要扫描的插件根目录列表。
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
|
||||
候选插件映射和重复插件 ID 冲突映射。
|
||||
"""
|
||||
return self._discover_candidates(plugin_dirs)
|
||||
|
||||
def _discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
|
||||
"""扫描插件目录并收集候选插件。"""
|
||||
candidates: Dict[str, PluginCandidate] = {}
|
||||
@@ -170,7 +206,6 @@ class PluginLoader:
|
||||
plugin_dir, manifest, plugin_path = candidates[plugin_id]
|
||||
try:
|
||||
if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path):
|
||||
self._loaded_plugins[meta.plugin_id] = meta
|
||||
results.append(meta)
|
||||
except Exception as e:
|
||||
self._failed_plugins[plugin_id] = str(e)
|
||||
@@ -182,22 +217,109 @@ class PluginLoader:
|
||||
"""获取已加载的插件"""
|
||||
return self._loaded_plugins.get(plugin_id)
|
||||
|
||||
def set_loaded_plugin(self, meta: PluginMeta) -> None:
|
||||
"""登记一个已经完成初始化的插件。
|
||||
|
||||
Args:
|
||||
meta: 待登记的插件元数据。
|
||||
"""
|
||||
self._loaded_plugins[meta.plugin_id] = meta
|
||||
|
||||
def remove_loaded_plugin(self, plugin_id: str) -> Optional[PluginMeta]:
|
||||
"""移除一个已加载插件的元数据。
|
||||
|
||||
Args:
|
||||
plugin_id: 待移除的插件 ID。
|
||||
|
||||
Returns:
|
||||
Optional[PluginMeta]: 被移除的插件元数据;不存在时返回 ``None``。
|
||||
"""
|
||||
return self._loaded_plugins.pop(plugin_id, None)
|
||||
|
||||
def purge_plugin_modules(self, plugin_id: str, plugin_dir: str) -> List[str]:
|
||||
"""清理指定插件目录下的模块缓存。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
plugin_dir: 插件目录绝对路径。
|
||||
|
||||
Returns:
|
||||
List[str]: 已从 ``sys.modules`` 中移除的模块名列表。
|
||||
"""
|
||||
removed_modules: List[str] = []
|
||||
plugin_path = Path(plugin_dir).resolve()
|
||||
synthetic_module_name = f"_maibot_plugin_{plugin_id}"
|
||||
|
||||
for module_name, module in list(sys.modules.items()):
|
||||
if module_name == synthetic_module_name:
|
||||
removed_modules.append(module_name)
|
||||
sys.modules.pop(module_name, None)
|
||||
continue
|
||||
|
||||
module_file = getattr(module, "__file__", None)
|
||||
if module_file is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
module_path = Path(module_file).resolve()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if module_path.is_relative_to(plugin_path):
|
||||
removed_modules.append(module_name)
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
importlib.invalidate_caches()
|
||||
return removed_modules
|
||||
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已加载的插件 ID"""
|
||||
return list(self._loaded_plugins.keys())
|
||||
|
||||
@property
|
||||
def failed_plugins(self) -> Dict[str, str]:
|
||||
"""返回当前记录的失败插件原因映射。"""
|
||||
return dict(self._failed_plugins)
|
||||
|
||||
# ──── 依赖解析 ────────────────────────────────────────────
|
||||
|
||||
def resolve_dependencies(
|
||||
self,
|
||||
candidates: Dict[str, PluginCandidate],
|
||||
extra_available: Optional[Set[str]] = None,
|
||||
) -> Tuple[List[str], Dict[str, str]]:
|
||||
"""解析候选插件的依赖顺序。
|
||||
|
||||
Args:
|
||||
candidates: 待加载的候选插件集合。
|
||||
extra_available: 视为已满足的外部依赖插件 ID 集合。
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。
|
||||
"""
|
||||
return self._resolve_dependencies(candidates, extra_available=extra_available)
|
||||
|
||||
def load_candidate(self, plugin_id: str, candidate: PluginCandidate) -> Optional[PluginMeta]:
|
||||
"""加载单个候选插件模块。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
candidate: 候选插件三元组。
|
||||
|
||||
Returns:
|
||||
Optional[PluginMeta]: 加载成功的插件元数据;失败时返回 ``None``。
|
||||
"""
|
||||
plugin_dir, manifest, plugin_path = candidate
|
||||
return self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path)
|
||||
|
||||
def _resolve_dependencies(
|
||||
self,
|
||||
candidates: Dict[str, PluginCandidate],
|
||||
extra_available: Optional[Set[str]] = None,
|
||||
) -> Tuple[List[str], Dict[str, str]]:
|
||||
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
|
||||
available = set(candidates.keys())
|
||||
satisfied_dependencies = set(extra_available or set())
|
||||
dep_graph: Dict[str, Set[str]] = {}
|
||||
failed: Dict[str, str] = {}
|
||||
|
||||
@@ -212,6 +334,8 @@ class PluginLoader:
|
||||
continue
|
||||
if dep_name in available:
|
||||
resolved.add(dep_name)
|
||||
elif dep_name in satisfied_dependencies:
|
||||
continue
|
||||
else:
|
||||
missing.append(dep_name)
|
||||
if missing:
|
||||
@@ -271,33 +395,39 @@ class PluginLoader:
|
||||
sys.modules[module_name] = module
|
||||
|
||||
plugin_parent_dir = plugin_dir.parent
|
||||
with self._temporary_sys_path_entry(plugin_parent_dir):
|
||||
spec.loader.exec_module(module)
|
||||
try:
|
||||
with self._temporary_sys_path_entry(plugin_parent_dir):
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 优先使用新版 create_plugin 工厂函数
|
||||
create_plugin = getattr(module, "create_plugin", None)
|
||||
if create_plugin is not None:
|
||||
instance = create_plugin()
|
||||
logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=str(plugin_dir),
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
# 优先使用新版 create_plugin 工厂函数
|
||||
create_plugin = getattr(module, "create_plugin", None)
|
||||
if create_plugin is not None:
|
||||
instance = create_plugin()
|
||||
logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=str(plugin_dir),
|
||||
module_name=module_name,
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
|
||||
# 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
|
||||
instance = self._try_load_legacy_plugin(module, plugin_id)
|
||||
if instance is not None:
|
||||
logger.info(
|
||||
f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
|
||||
)
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=str(plugin_dir),
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
# 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
|
||||
instance = self._try_load_legacy_plugin(module, plugin_id)
|
||||
if instance is not None:
|
||||
logger.info(
|
||||
f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)"
|
||||
)
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=str(plugin_dir),
|
||||
module_name=module_name,
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
except Exception:
|
||||
sys.modules.pop(module_name, None)
|
||||
raise
|
||||
|
||||
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin")
|
||||
return None
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
"""Runner 端 RPC Client
|
||||
"""Runner 端 RPC 客户端。"""
|
||||
|
||||
负责:
|
||||
1. 连接 Host RPC Server
|
||||
2. 发送握手(runner.hello)
|
||||
3. 发送组件注册请求
|
||||
4. 接收并分发 Host 的调用请求
|
||||
5. 发送能力调用请求到 Host
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Set, cast
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -29,12 +21,15 @@ from src.plugin_runtime.transport.factory import create_transport_client
|
||||
|
||||
logger = get_logger("plugin_runtime.runner.rpc_client")
|
||||
|
||||
# RPC 方法处理器类型
|
||||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||||
|
||||
|
||||
def _get_sdk_version() -> str:
|
||||
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
|
||||
"""读取 SDK 版本号。
|
||||
|
||||
Returns:
|
||||
str: 已安装的 SDK 版本;读取失败时回退到 ``1.0.0``。
|
||||
"""
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
@@ -47,73 +42,78 @@ SDK_VERSION = _get_sdk_version()
|
||||
|
||||
|
||||
class RPCClient:
|
||||
"""Runner 端 RPC 客户端
|
||||
|
||||
管理与 Host 的 IPC 连接,支持双向 RPC 调用。
|
||||
"""
|
||||
"""Runner 端 RPC 客户端。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_address: str,
|
||||
session_token: str,
|
||||
codec: Optional[Codec] = None,
|
||||
):
|
||||
self._host_address = host_address
|
||||
self._session_token = session_token
|
||||
self._codec = codec or MsgPackCodec()
|
||||
) -> None:
|
||||
"""初始化 RPC 客户端。
|
||||
|
||||
Args:
|
||||
host_address: Host 的 IPC 地址。
|
||||
session_token: 握手用会话令牌。
|
||||
codec: 可选的编解码器实现。
|
||||
"""
|
||||
self._host_address: str = host_address
|
||||
self._session_token: str = session_token
|
||||
self._codec: Codec = codec or MsgPackCodec()
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Optional[Connection] = None
|
||||
self._runner_id = str(uuid.uuid4())
|
||||
self._generation: int = 0
|
||||
|
||||
# 方法处理器注册表(Host 发来的调用)
|
||||
self._runner_id: str = str(uuid.uuid4())
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||||
|
||||
# 运行状态
|
||||
self._running = False
|
||||
self._recv_task: Optional[asyncio.Task] = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
@property
|
||||
def generation(self) -> int:
|
||||
return self._generation
|
||||
self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
|
||||
self._running: bool = False
|
||||
self._recv_task: Optional[asyncio.Task[None]] = None
|
||||
self._background_tasks: Set[asyncio.Task[Any]] = set()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""返回当前连接是否可用。"""
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册方法处理器(处理 Host 发来的请求)"""
|
||||
"""注册 Host -> Runner 的 RPC 处理器。
|
||||
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
handler: 方法处理函数。
|
||||
"""
|
||||
self._method_handlers[method] = handler
|
||||
|
||||
def _require_connection(self) -> Connection:
|
||||
"""返回当前可用连接;若连接不可用则抛出 RPCError。"""
|
||||
"""返回当前可用连接。
|
||||
|
||||
Returns:
|
||||
Connection: 当前连接对象。
|
||||
|
||||
Raises:
|
||||
RPCError: 当前未连接到 Host。
|
||||
"""
|
||||
connection = self._connection
|
||||
if connection is None or connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
|
||||
return cast(Connection, connection)
|
||||
|
||||
async def connect_and_handshake(self) -> bool:
|
||||
"""连接 Host 并完成握手
|
||||
"""连接 Host 并完成握手。
|
||||
|
||||
Returns:
|
||||
是否握手成功
|
||||
bool: 是否握手成功。
|
||||
"""
|
||||
client = create_transport_client(self._host_address)
|
||||
self._connection = await client.connect()
|
||||
connection = self._require_connection()
|
||||
|
||||
# 发送 runner.hello
|
||||
hello = HelloPayload(
|
||||
runner_id=self._runner_id,
|
||||
sdk_version=SDK_VERSION,
|
||||
session_token=self._session_token,
|
||||
)
|
||||
request_id = self._id_gen.next()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
@@ -121,33 +121,27 @@ class RPCClient:
|
||||
payload=hello.model_dump(),
|
||||
)
|
||||
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await connection.send_frame(data)
|
||||
await connection.send_frame(self._codec.encode_envelope(envelope))
|
||||
|
||||
# 接收握手响应
|
||||
resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
|
||||
resp = self._codec.decode_envelope(resp_data)
|
||||
response = self._codec.decode_envelope(resp_data)
|
||||
resp_payload = HelloResponsePayload.model_validate(response.payload)
|
||||
|
||||
resp_payload = HelloResponsePayload.model_validate(resp.payload)
|
||||
if not resp_payload.accepted:
|
||||
logger.error(f"握手被拒绝: {resp_payload.reason}")
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
await self.disconnect()
|
||||
return False
|
||||
|
||||
self._generation = resp_payload.assigned_generation
|
||||
logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")
|
||||
|
||||
# 启动消息接收循环
|
||||
logger.info(f"握手成功: host_version={resp_payload.host_version}")
|
||||
self._running = True
|
||||
self._recv_task = asyncio.create_task(self._recv_loop())
|
||||
|
||||
self._recv_task = asyncio.create_task(self._recv_loop(), name="RPCClient.recv")
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""断开连接"""
|
||||
"""断开与 Host 的连接并清理状态。"""
|
||||
self._running = False
|
||||
if self._recv_task:
|
||||
|
||||
if self._recv_task is not None:
|
||||
self._recv_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._recv_task
|
||||
@@ -160,13 +154,12 @@ class RPCClient:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
if self._connection:
|
||||
if self._connection is not None:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
@@ -177,16 +170,27 @@ class RPCClient:
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""向 Host 发送 RPC 请求并等待响应"""
|
||||
connection = self._require_connection()
|
||||
"""向 Host 发送 RPC 请求并等待响应。
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
Args:
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
payload: 请求载荷。
|
||||
timeout_ms: 超时时间,单位毫秒。
|
||||
|
||||
Returns:
|
||||
Envelope: Host 返回的响应信封。
|
||||
|
||||
Raises:
|
||||
RPCError: 发送失败、超时或连接异常。
|
||||
"""
|
||||
connection = self._require_connection()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
@@ -196,21 +200,16 @@ class RPCClient:
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await connection.send_frame(data)
|
||||
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
return await asyncio.wait_for(future, timeout=timeout_sec)
|
||||
await connection.send_frame(self._codec.encode_envelope(envelope))
|
||||
return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
|
||||
except Exception as e:
|
||||
except Exception as exc:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
if isinstance(e, RPCError):
|
||||
if isinstance(exc, RPCError):
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(exc)) from exc
|
||||
|
||||
async def send_event(
|
||||
self,
|
||||
@@ -218,33 +217,30 @@ class RPCClient:
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""向 Host 发送单向事件(fire-and-forget,不等待响应)。
|
||||
"""向 Host 发送单向广播消息。
|
||||
|
||||
Args:
|
||||
method: RPC 方法名,如 "runner.log_batch"。
|
||||
plugin_id: 目标插件 ID(可为空,表示 Runner 级消息)。
|
||||
payload: 事件数据。
|
||||
method: RPC 方法名。
|
||||
plugin_id: 目标插件 ID。
|
||||
payload: 广播载荷。
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return
|
||||
|
||||
connection = self._require_connection()
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.EVENT,
|
||||
message_type=MessageType.BROADCAST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._generation,
|
||||
payload=payload or {},
|
||||
)
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await connection.send_frame(data)
|
||||
await connection.send_frame(self._codec.encode_envelope(envelope))
|
||||
|
||||
async def _recv_loop(self) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and self._connection and not self._connection.is_closed:
|
||||
"""持续接收 Host 发来的消息并分发。"""
|
||||
while self._running and self._connection is not None and not self._connection.is_closed:
|
||||
try:
|
||||
data = await self._connection.recv_frame()
|
||||
except (asyncio.IncompleteReadError, ConnectionError):
|
||||
@@ -252,39 +248,47 @@ class RPCClient:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"接收帧失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"接收帧失败: {exc}")
|
||||
break
|
||||
|
||||
try:
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
except Exception as e:
|
||||
logger.error(f"解码消息失败: {e}")
|
||||
except Exception as exc:
|
||||
logger.error(f"解码消息失败: {exc}")
|
||||
continue
|
||||
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
|
||||
elif envelope.is_event():
|
||||
self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
|
||||
elif envelope.is_broadcast():
|
||||
self._track_background_task(asyncio.create_task(self._handle_broadcast(envelope)))
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的响应"""
|
||||
"""处理 Host 返回的响应。
|
||||
|
||||
Args:
|
||||
envelope: 响应信封。
|
||||
"""
|
||||
future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if future and not future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
if future is None or future.done():
|
||||
return
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
async def _handle_request(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的请求(调用插件组件)"""
|
||||
"""处理 Host 发来的请求。
|
||||
|
||||
Args:
|
||||
envelope: 请求信封。
|
||||
"""
|
||||
connection = self._connection
|
||||
if connection is None or connection.is_closed:
|
||||
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
|
||||
return
|
||||
connection = cast(Connection, connection)
|
||||
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
@@ -298,23 +302,34 @@ class RPCClient:
|
||||
try:
|
||||
response = await handler(envelope)
|
||||
await connection.send_frame(self._codec.encode_envelope(response))
|
||||
except RPCError as e:
|
||||
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
except RPCError as exc:
|
||||
error_resp = envelope.make_error_response(exc.code.value, exc.message, exc.details)
|
||||
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
except Exception as e:
|
||||
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
except Exception as exc:
|
||||
logger.error(f"处理请求 {envelope.method} 异常: {exc}", exc_info=True)
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc))
|
||||
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
|
||||
async def _handle_event(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的事件"""
|
||||
if handler := self._method_handlers.get(envelope.method):
|
||||
try:
|
||||
await handler(envelope)
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
async def _handle_broadcast(self, envelope: Envelope) -> None:
|
||||
"""处理 Host 发来的广播事件。
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task) -> None:
|
||||
"""保持后台任务强引用,直到其完成或被取消。"""
|
||||
Args:
|
||||
envelope: 广播信封。
|
||||
"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await handler(envelope)
|
||||
except Exception as exc:
|
||||
logger.error(f"处理广播 {envelope.method} 异常: {exc}", exc_info=True)
|
||||
|
||||
def _track_background_task(self, task: asyncio.Task[Any]) -> None:
|
||||
"""持有后台任务强引用直到其结束。
|
||||
|
||||
Args:
|
||||
task: 后台任务。
|
||||
"""
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
6. 转发插件的能力调用到 Host
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, List, Optional, Protocol, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@@ -32,8 +32,11 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
HealthPayload,
|
||||
InvokePayload,
|
||||
InvokeResultPayload,
|
||||
RegisterComponentsPayload,
|
||||
RegisterPluginPayload,
|
||||
ReloadPluginPayload,
|
||||
ReloadPluginResultPayload,
|
||||
RunnerReadyPayload,
|
||||
UnregisterPluginPayload,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode
|
||||
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
|
||||
@@ -44,7 +47,8 @@ logger = get_logger("plugin_runtime.runner.main")
|
||||
|
||||
|
||||
class _ContextAwarePlugin(Protocol):
|
||||
def _set_context(self, context: Any) -> None: ...
|
||||
def _set_context(self, context: Any) -> None:
|
||||
"""为插件注入上下文对象。"""
|
||||
|
||||
|
||||
def _install_shutdown_signal_handlers(
|
||||
@@ -90,21 +94,29 @@ class PluginRunner:
|
||||
session_token: str,
|
||||
plugin_dirs: List[str],
|
||||
) -> None:
|
||||
"""初始化 Runner。
|
||||
|
||||
Args:
|
||||
host_address: Host 的 IPC 地址。
|
||||
session_token: 握手用会话令牌。
|
||||
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
|
||||
"""
|
||||
self._host_address: str = host_address
|
||||
self._session_token: str = session_token
|
||||
self._plugin_dirs: list[str] = plugin_dirs
|
||||
self._plugin_dirs: List[str] = plugin_dirs
|
||||
|
||||
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
|
||||
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
|
||||
self._start_time: float = time.monotonic()
|
||||
self._shutting_down: bool = False
|
||||
self._reload_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
# IPC 日志 Handler:握手成功后安装,将所有 stdlib logging 转发到 Host
|
||||
self._log_handler: Optional[RunnerIPCLogHandler] = None
|
||||
self._suspended_console_handlers: list[stdlib_logging.Handler] = []
|
||||
self._suspended_console_handlers: List[stdlib_logging.Handler] = []
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Runner 主入口"""
|
||||
"""运行 Runner 主循环。"""
|
||||
# 1. 连接 Host
|
||||
logger.info(f"Runner 启动,连接 Host: {self._host_address}")
|
||||
ok = await self._rpc_client.connect_and_handshake()
|
||||
@@ -123,32 +135,11 @@ class PluginRunner:
|
||||
logger.info(f"已加载 {len(plugins)} 个插件")
|
||||
|
||||
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
|
||||
failed_plugins: set[str] = set()
|
||||
failed_plugins: Set[str] = set(self._loader.failed_plugins.keys())
|
||||
for meta in plugins:
|
||||
instance = meta.instance
|
||||
self._inject_context(meta.plugin_id, instance)
|
||||
self._apply_plugin_config(meta)
|
||||
if not await self._bootstrap_plugin(meta):
|
||||
failed_plugins.add(meta.plugin_id)
|
||||
continue
|
||||
if hasattr(instance, "on_load"):
|
||||
try:
|
||||
ret = instance.on_load()
|
||||
if asyncio.iscoroutine(ret):
|
||||
await ret
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {meta.plugin_id} on_load 失败,跳过注册: {e}", exc_info=True)
|
||||
failed_plugins.add(meta.plugin_id)
|
||||
await self._deactivate_plugin(meta)
|
||||
|
||||
# 5. 向 Host 注册所有插件的组件(跳过 on_load 失败的插件)
|
||||
for meta in plugins:
|
||||
if meta.plugin_id in failed_plugins:
|
||||
continue
|
||||
ok = await self._register_plugin(meta)
|
||||
ok = await self._activate_plugin(meta)
|
||||
if not ok:
|
||||
failed_plugins.add(meta.plugin_id)
|
||||
await self._deactivate_plugin(meta)
|
||||
|
||||
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
|
||||
await self._notify_ready(successful_plugins, sorted(failed_plugins))
|
||||
@@ -232,7 +223,9 @@ class PluginRunner:
|
||||
bound_plugin_id = plugin_id
|
||||
|
||||
async def _rpc_call(
|
||||
method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None
|
||||
method: str,
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
"""桥接 PluginContext.call_capability → RPCClient.send_request。
|
||||
|
||||
@@ -257,7 +250,7 @@ class PluginRunner:
|
||||
cast(_ContextAwarePlugin, instance)._set_context(ctx)
|
||||
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
|
||||
|
||||
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
|
||||
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""在 Runner 侧为插件实例注入当前插件配置。"""
|
||||
instance = meta.instance
|
||||
if not hasattr(instance, "set_plugin_config"):
|
||||
@@ -270,7 +263,7 @@ class PluginRunner:
|
||||
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
|
||||
|
||||
@staticmethod
|
||||
def _load_plugin_config(plugin_dir: str) -> dict[str, Any]:
|
||||
def _load_plugin_config(plugin_dir: str) -> Dict[str, Any]:
|
||||
"""从插件目录读取 config.toml。"""
|
||||
config_path = Path(plugin_dir) / "config.toml"
|
||||
if not config_path.exists():
|
||||
@@ -286,16 +279,18 @@ class PluginRunner:
|
||||
return loaded if isinstance(loaded, dict) else {}
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""注册方法处理器"""
|
||||
"""注册 Host -> Runner 的方法处理器。"""
|
||||
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
|
||||
self._rpc_client.register_method("plugin.health", self._handle_health)
|
||||
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
|
||||
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
|
||||
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
|
||||
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
|
||||
|
||||
async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool:
|
||||
"""向 Host 同步插件 bootstrap 能力令牌。"""
|
||||
@@ -324,7 +319,14 @@ class PluginRunner:
|
||||
await self._bootstrap_plugin(meta, capabilities_required=[])
|
||||
|
||||
async def _register_plugin(self, meta: PluginMeta) -> bool:
|
||||
"""向 Host 注册单个插件"""
|
||||
"""向 Host 注册单个插件。
|
||||
|
||||
Args:
|
||||
meta: 待注册的插件元数据。
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功。
|
||||
"""
|
||||
# 收集插件组件声明
|
||||
components: List[ComponentDeclaration] = []
|
||||
instance = meta.instance
|
||||
@@ -341,7 +343,7 @@ class PluginRunner:
|
||||
for comp_info in instance.get_components()
|
||||
)
|
||||
|
||||
reg_payload = RegisterComponentsPayload(
|
||||
reg_payload = RegisterPluginPayload(
|
||||
plugin_id=meta.plugin_id,
|
||||
plugin_version=meta.version,
|
||||
components=components,
|
||||
@@ -361,8 +363,281 @@ class PluginRunner:
|
||||
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
|
||||
return False
|
||||
|
||||
async def _unregister_plugin(self, plugin_id: str, reason: str) -> None:
|
||||
"""通知 Host 注销指定插件。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
reason: 注销原因。
|
||||
"""
|
||||
payload = UnregisterPluginPayload(plugin_id=plugin_id, reason=reason)
|
||||
try:
|
||||
await self._rpc_client.send_request(
|
||||
"plugin.unregister",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=10000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"插件 {plugin_id} 注销通知失败: {exc}")
|
||||
|
||||
async def _invoke_plugin_on_load(self, meta: PluginMeta) -> bool:
|
||||
"""执行插件的 ``on_load`` 生命周期。
|
||||
|
||||
Args:
|
||||
meta: 待初始化的插件元数据。
|
||||
|
||||
Returns:
|
||||
bool: 生命周期是否执行成功。
|
||||
"""
|
||||
instance = meta.instance
|
||||
if not hasattr(instance, "on_load"):
|
||||
return True
|
||||
|
||||
try:
|
||||
result = instance.on_load()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {meta.plugin_id} on_load 失败: {exc}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _invoke_plugin_on_unload(self, meta: PluginMeta) -> None:
|
||||
"""执行插件的 ``on_unload`` 生命周期。
|
||||
|
||||
Args:
|
||||
meta: 待卸载的插件元数据。
|
||||
"""
|
||||
instance = meta.instance
|
||||
if not hasattr(instance, "on_unload"):
|
||||
return
|
||||
|
||||
try:
|
||||
result = instance.on_unload()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True)
|
||||
|
||||
async def _activate_plugin(self, meta: PluginMeta) -> bool:
|
||||
"""完成插件注入、授权、生命周期和组件注册。
|
||||
|
||||
Args:
|
||||
meta: 待激活的插件元数据。
|
||||
|
||||
Returns:
|
||||
bool: 是否激活成功。
|
||||
"""
|
||||
self._inject_context(meta.plugin_id, meta.instance)
|
||||
self._apply_plugin_config(meta)
|
||||
|
||||
if not await self._bootstrap_plugin(meta):
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
if not await self._invoke_plugin_on_load(meta):
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
if not await self._register_plugin(meta):
|
||||
await self._invoke_plugin_on_unload(meta)
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
return False
|
||||
|
||||
self._loader.set_loaded_plugin(meta)
|
||||
return True
|
||||
|
||||
async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None:
|
||||
"""卸载单个插件并清理 Host/Runner 两侧状态。
|
||||
|
||||
Args:
|
||||
meta: 待卸载的插件元数据。
|
||||
reason: 卸载原因。
|
||||
"""
|
||||
await self._invoke_plugin_on_unload(meta)
|
||||
await self._unregister_plugin(meta.plugin_id, reason)
|
||||
await self._deactivate_plugin(meta)
|
||||
self._loader.remove_loaded_plugin(meta.plugin_id)
|
||||
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
|
||||
|
||||
def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]:
|
||||
"""收集依赖指定插件的所有已加载插件。
|
||||
|
||||
Args:
|
||||
plugin_id: 根插件 ID。
|
||||
|
||||
Returns:
|
||||
Set[str]: 目标插件及其所有反向依赖插件集合。
|
||||
"""
|
||||
impacted_plugins: Set[str] = {plugin_id}
|
||||
changed = True
|
||||
|
||||
while changed:
|
||||
changed = False
|
||||
for loaded_plugin_id in self._loader.list_plugins():
|
||||
if loaded_plugin_id in impacted_plugins:
|
||||
continue
|
||||
|
||||
meta = self._loader.get_plugin(loaded_plugin_id)
|
||||
if meta is None:
|
||||
continue
|
||||
|
||||
if any(dependency in impacted_plugins for dependency in meta.dependencies):
|
||||
impacted_plugins.add(loaded_plugin_id)
|
||||
changed = True
|
||||
|
||||
return impacted_plugins
|
||||
|
||||
def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]:
|
||||
"""构建受影响插件的卸载顺序。
|
||||
|
||||
Args:
|
||||
plugin_ids: 需要卸载的插件集合。
|
||||
|
||||
Returns:
|
||||
List[str]: 依赖方优先的卸载顺序。
|
||||
"""
|
||||
dependency_graph: Dict[str, Set[str]] = {}
|
||||
for plugin_id in plugin_ids:
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta is None:
|
||||
dependency_graph[plugin_id] = set()
|
||||
continue
|
||||
dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids}
|
||||
|
||||
indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()}
|
||||
reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph}
|
||||
|
||||
for plugin_id, dependencies in dependency_graph.items():
|
||||
for dependency in dependencies:
|
||||
reverse_graph.setdefault(dependency, set()).add(plugin_id)
|
||||
|
||||
queue: List[str] = sorted(plugin_id for plugin_id, degree in indegree.items() if degree == 0)
|
||||
load_order: List[str] = []
|
||||
|
||||
while queue:
|
||||
current_plugin_id = queue.pop(0)
|
||||
load_order.append(current_plugin_id)
|
||||
for dependent_plugin_id in sorted(reverse_graph.get(current_plugin_id, set())):
|
||||
indegree[dependent_plugin_id] -= 1
|
||||
if indegree[dependent_plugin_id] == 0:
|
||||
queue.append(dependent_plugin_id)
|
||||
queue.sort()
|
||||
|
||||
return list(reversed(load_order))
|
||||
|
||||
async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload:
|
||||
"""按插件 ID 在 Runner 进程内执行精确重载。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
reason: 重载原因。
|
||||
|
||||
Returns:
|
||||
ReloadPluginResultPayload: 结构化重载结果。
|
||||
"""
|
||||
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
|
||||
failed_plugins: Dict[str, str] = {}
|
||||
|
||||
if plugin_id in duplicate_candidates:
|
||||
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
|
||||
return ReloadPluginResultPayload(
|
||||
success=False,
|
||||
requested_plugin_id=plugin_id,
|
||||
failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"},
|
||||
)
|
||||
|
||||
loaded_plugin_ids = set(self._loader.list_plugins())
|
||||
plugin_is_loaded = plugin_id in loaded_plugin_ids
|
||||
plugin_has_candidate = plugin_id in candidates
|
||||
|
||||
if not plugin_is_loaded and not plugin_has_candidate:
|
||||
return ReloadPluginResultPayload(
|
||||
success=False,
|
||||
requested_plugin_id=plugin_id,
|
||||
failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"},
|
||||
)
|
||||
|
||||
target_plugin_ids: Set[str] = {plugin_id}
|
||||
if plugin_is_loaded:
|
||||
target_plugin_ids = self._collect_reverse_dependents(plugin_id)
|
||||
|
||||
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
|
||||
unloaded_plugins: List[str] = []
|
||||
retained_plugin_ids = loaded_plugin_ids - set(unload_order)
|
||||
|
||||
for unload_plugin_id in unload_order:
|
||||
meta = self._loader.get_plugin(unload_plugin_id)
|
||||
if meta is None:
|
||||
continue
|
||||
await self._unload_plugin(meta, reason=reason)
|
||||
unloaded_plugins.append(unload_plugin_id)
|
||||
|
||||
reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {}
|
||||
for target_plugin_id in target_plugin_ids:
|
||||
candidate = candidates.get(target_plugin_id)
|
||||
if candidate is None:
|
||||
failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态"
|
||||
continue
|
||||
reload_candidates[target_plugin_id] = candidate
|
||||
|
||||
load_order, dependency_failures = self._loader.resolve_dependencies(
|
||||
reload_candidates,
|
||||
extra_available=retained_plugin_ids,
|
||||
)
|
||||
failed_plugins.update(dependency_failures)
|
||||
|
||||
available_plugins = set(retained_plugin_ids)
|
||||
reloaded_plugins: List[str] = []
|
||||
|
||||
for load_plugin_id in load_order:
|
||||
if load_plugin_id in failed_plugins:
|
||||
continue
|
||||
|
||||
candidate = reload_candidates.get(load_plugin_id)
|
||||
if candidate is None:
|
||||
continue
|
||||
|
||||
_, manifest, _ = candidate
|
||||
dependencies = PluginMeta._extract_dependencies(manifest)
|
||||
missing_dependencies = [dependency for dependency in dependencies if dependency not in available_plugins]
|
||||
if missing_dependencies:
|
||||
failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(missing_dependencies)}"
|
||||
continue
|
||||
|
||||
meta = self._loader.load_candidate(load_plugin_id, candidate)
|
||||
if meta is None:
|
||||
failed_plugins[load_plugin_id] = "插件模块加载失败"
|
||||
continue
|
||||
|
||||
activated = await self._activate_plugin(meta)
|
||||
if not activated:
|
||||
failed_plugins[load_plugin_id] = "插件初始化失败"
|
||||
continue
|
||||
|
||||
available_plugins.add(load_plugin_id)
|
||||
reloaded_plugins.append(load_plugin_id)
|
||||
|
||||
requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins
|
||||
|
||||
return ReloadPluginResultPayload(
|
||||
success=requested_plugin_success,
|
||||
requested_plugin_id=plugin_id,
|
||||
reloaded_plugins=reloaded_plugins,
|
||||
unloaded_plugins=unloaded_plugins,
|
||||
failed_plugins=failed_plugins,
|
||||
)
|
||||
|
||||
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
|
||||
"""通知 Host 当前 generation 已完成插件初始化。"""
|
||||
"""通知 Host 当前 Runner 已完成插件初始化。
|
||||
|
||||
Args:
|
||||
loaded_plugins: 成功初始化的插件列表。
|
||||
failed_plugins: 初始化失败的插件列表。
|
||||
"""
|
||||
payload = RunnerReadyPayload(
|
||||
loaded_plugins=loaded_plugins,
|
||||
failed_plugins=failed_plugins,
|
||||
@@ -487,6 +762,61 @@ class PluginRunner:
|
||||
logger.error(f"插件 {plugin_id} event_handler {component_name} 执行异常: {e}", exc_info=True)
|
||||
return envelope.make_response(payload={"success": False, "continue_processing": True})
|
||||
|
||||
async def _handle_hook_invoke(self, envelope: Envelope) -> Envelope:
|
||||
"""处理 HookHandler 调用请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: 标准化后的 Hook 调用结果。
|
||||
"""
|
||||
try:
|
||||
invoke = InvokePayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
f"插件 {plugin_id} 未加载",
|
||||
)
|
||||
|
||||
instance = meta.instance
|
||||
component_name = invoke.component_name
|
||||
handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None)
|
||||
if handler_method is None or not callable(handler_method):
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"插件 {plugin_id} 无组件: {component_name}",
|
||||
)
|
||||
|
||||
try:
|
||||
raw = (
|
||||
await handler_method(**invoke.args)
|
||||
if inspect.iscoroutinefunction(handler_method)
|
||||
else handler_method(**invoke.args)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True)
|
||||
return envelope.make_response(payload={"success": False, "continue_processing": True})
|
||||
|
||||
if raw is None:
|
||||
result = {"success": True, "continue_processing": True}
|
||||
elif isinstance(raw, dict):
|
||||
result = {
|
||||
"success": True,
|
||||
"continue_processing": raw.get("continue_processing", True),
|
||||
"modified_kwargs": raw.get("modified_kwargs"),
|
||||
"custom_result": raw.get("custom_result"),
|
||||
}
|
||||
else:
|
||||
result = {"success": True, "continue_processing": True, "custom_result": raw}
|
||||
|
||||
return envelope.make_response(payload=result)
|
||||
|
||||
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
|
||||
"""处理 WorkflowStep 调用请求
|
||||
|
||||
@@ -557,15 +887,10 @@ class PluginRunner:
|
||||
async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
|
||||
"""处理关停 — 调用所有插件的 on_unload 后退出"""
|
||||
logger.info("收到 shutdown 信号,开始调用 on_unload")
|
||||
for plugin_id in self._loader.list_plugins():
|
||||
for plugin_id in list(self._loader.list_plugins()):
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta and hasattr(meta.instance, "on_unload"):
|
||||
try:
|
||||
ret = meta.instance.on_unload()
|
||||
if asyncio.iscoroutine(ret):
|
||||
await ret
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True)
|
||||
if meta is not None:
|
||||
await self._unload_plugin(meta, reason="runner_shutdown")
|
||||
self._shutting_down = True
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
@@ -587,6 +912,30 @@ class PluginRunner:
|
||||
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope:
|
||||
"""处理按插件 ID 的精确重载请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: 结构化重载结果。
|
||||
"""
|
||||
try:
|
||||
payload = ReloadPluginPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
if self._reload_lock.locked():
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_RELOAD_IN_PROGRESS.value,
|
||||
f"插件 {payload.plugin_id} 重载请求被拒绝:已有重载任务正在执行",
|
||||
)
|
||||
|
||||
async with self._reload_lock:
|
||||
result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason)
|
||||
return envelope.make_response(payload=result.model_dump())
|
||||
|
||||
def request_capability(self) -> RPCClient:
|
||||
"""获取 RPC 客户端(供 SDK 使用,发起能力调用)"""
|
||||
return self._rpc_client
|
||||
@@ -652,13 +1001,16 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
||||
|
||||
_ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")
|
||||
|
||||
def find_module(self, fullname, path=None):
|
||||
def find_module(self, fullname: str, path: Any = None) -> Any:
|
||||
"""决定是否拦截指定模块导入。"""
|
||||
return self if self._should_block(fullname) else None
|
||||
|
||||
def load_module(self, fullname):
|
||||
def load_module(self, fullname: str) -> None:
|
||||
"""阻止被拦截模块继续导入。"""
|
||||
raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
|
||||
|
||||
def _should_block(self, fullname: str) -> bool:
|
||||
"""判断给定模块名是否应被阻止导入。"""
|
||||
# 放行非 src.* 的导入、以及 "src" 本身
|
||||
if not fullname.startswith("src.") or fullname == "src":
|
||||
return False
|
||||
@@ -692,6 +1044,7 @@ async def _async_main() -> None:
|
||||
|
||||
# 注册信号处理
|
||||
def _mark_runner_shutting_down() -> None:
|
||||
"""标记 Runner 即将进入关停流程。"""
|
||||
runner._shutting_down = True
|
||||
|
||||
_install_shutdown_signal_handlers(_mark_runner_shutting_down)
|
||||
|
||||
Reference in New Issue
Block a user