From 0c508995ddbfa42020485bdc23f4d5a62b1e54f1 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 21:48:19 +0800 Subject: [PATCH] feat: enhance session ID calculation and plugin management - Updated `calculate_session_id` method in `SessionUtils` to include optional `account_id` and `scope` parameters for more granular session ID generation. - Added new environment variables in `plugin_runtime` for external plugin dependencies and global configuration snapshots. - Introduced methods in `RuntimeComponentManagerProtocol` for loading and reloading plugins globally, accommodating external dependencies. - Enhanced `PluginRunnerSupervisor` to manage external available plugin IDs during plugin reloads. - Implemented dependency extraction and management in `PluginRuntimeManager` to handle cross-supervisor dependencies. - Added tests for session ID calculation and message registration in `ChatManager` to ensure correct behavior with new parameters. --- pytests/test_plugin_runtime.py | 114 ++++--- pytests/utils_test/test_session_utils.py | 42 +++ src/chat/message_receive/bot.py | 16 +- src/chat/message_receive/chat_manager.py | 58 +++- src/common/utils/utils_session.py | 22 +- src/plugin_runtime/__init__.py | 6 + src/plugin_runtime/capabilities/components.py | 136 ++++---- src/plugin_runtime/host/supervisor.py | 86 +++-- src/plugin_runtime/integration.py | 302 +++++++++++++++++- src/plugin_runtime/protocol/envelope.py | 4 + src/plugin_runtime/runner/plugin_loader.py | 9 +- src/plugin_runtime/runner/runner_main.py | 140 +++++++- 12 files changed, 765 insertions(+), 170 deletions(-) create mode 100644 pytests/utils_test/test_session_utils.py diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 9b46f897..e094d85b 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -3,6 +3,7 @@ 验证协议层、传输层、RPC 通信链路的正确性。 """ +from pathlib import Path from types import SimpleNamespace import asyncio @@ -2362,6 +2363,8 @@ class TestIntegration: from src.plugin_runtime import integration as integration_module instances = [] + builtin_dir = Path("builtin") + thirdparty_dir = Path("thirdparty") class FakeCapabilityService: def register_capability(self, name, impl): @@ -2369,11 +2372,18 @@ class TestIntegration: class FakeSupervisor: def __init__(self, plugin_dirs=None, socket_path=None): - self.plugin_dirs = plugin_dirs or [] + self._plugin_dirs = plugin_dirs or [] self.capability_service = FakeCapabilityService() + self.external_plugin_ids = [] self.stopped = False instances.append(self) + def set_external_available_plugin_ids(self, plugin_ids): + self.external_plugin_ids = list(plugin_ids) + + def get_loaded_plugin_ids(self): + return [] + async def start(self): if len(instances) == 2 and self is instances[1]: raise RuntimeError("boom") @@ -2382,10 +2392,10 @@ class TestIntegration: self.stopped = True monkeypatch.setattr( - integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"]) + integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir]) ) monkeypatch.setattr( - integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"]) + integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir]) ) import src.plugin_runtime.host.supervisor as supervisor_module @@ -2427,8 +2437,11 @@ class TestIntegration: self.reload_reasons = [] self.config_updates = [] - async def reload_plugins(self, plugin_ids=None, reason="manual"): - self.reload_reasons.append((plugin_ids, reason)) + def get_loaded_plugin_ids(self): + return sorted(self._registered_plugins.keys()) + + async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): + self.reload_reasons.append((plugin_ids, reason, external_available_plugins or [])) async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): self.config_updates.append((plugin_id, config_data, config_version)) @@ -2453,11 +2466,59 @@ class TestIntegration: await manager._handle_plugin_source_changes(changes) assert manager._builtin_supervisor.reload_reasons == [] - assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")] + assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher", ["alpha"])] assert manager._builtin_supervisor.config_updates == [] assert manager._third_party_supervisor.config_updates == [] assert refresh_calls == [True] + @pytest.mark.asyncio + async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch): + from src.plugin_runtime import integration as integration_module + + class FakeRegistration: + def __init__(self, dependencies): + self.dependencies = dependencies + + class FakeSupervisor: + def __init__(self, registrations): + self._registered_plugins = registrations + self.reload_calls = [] + + def get_loaded_plugin_ids(self): + return sorted(self._registered_plugins.keys()) + + async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): + self.reload_calls.append((plugin_ids, reason, sorted(external_available_plugins or []))) + return True + + builtin_supervisor = FakeSupervisor({"alpha": FakeRegistration([])}) + third_party_supervisor = FakeSupervisor( + { + "beta": FakeRegistration(["alpha"]), + "gamma": FakeRegistration(["beta"]), + } + ) + + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = builtin_supervisor + manager._third_party_supervisor = third_party_supervisor + warning_messages = [] + + monkeypatch.setattr( + integration_module.logger, + "warning", + lambda message: warning_messages.append(message), + ) + + reloaded = await manager.reload_plugins_globally(["alpha"], reason="manual") + + assert reloaded is True + assert builtin_supervisor.reload_calls == [(["alpha"], "manual", ["beta", "gamma"])] + assert third_party_supervisor.reload_calls == [] + assert len(warning_messages) == 1 + assert "beta, gamma" in warning_messages[0] + assert "跨 Supervisor API 调用仍然可用" in warning_messages[0] + @pytest.mark.asyncio async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module @@ -2623,55 +2684,30 @@ class TestIntegration: async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch): from src.plugin_runtime import integration as integration_module - class FakeSupervisor: - def __init__(self): - self._registered_plugins = {"alpha": object()} + manager = integration_module.PluginRuntimeManager() + monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False)) - async def reload_plugins(self, reason="manual"): - return False - - class FakeManager: - def __init__(self): - self.supervisors = [FakeSupervisor()] - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - - result = await integration_module.PluginRuntimeManager._cap_component_reload_plugin( + result = await manager._cap_component_reload_plugin( "plugin_a", "component.reload_plugin", {"plugin_name": "alpha"}, ) assert result["success"] is False - assert "已回滚" in result["error"] + assert result["error"] == "插件 alpha 热重载失败" @pytest.mark.asyncio async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module - plugin_root = tmp_path / "plugins" - plugin_root.mkdir() - (plugin_root / "alpha").mkdir() + manager = integration_module.PluginRuntimeManager() + monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False)) - class FakeSupervisor: - def __init__(self): - self._registered_plugins = {} - self._plugin_dirs = [str(plugin_root)] - - async def reload_plugins(self, reason="manual"): - return False - - class FakeManager: - def __init__(self): - self.supervisors = [FakeSupervisor()] - - monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) - - result = await integration_module.PluginRuntimeManager._cap_component_load_plugin( + result = await manager._cap_component_load_plugin( "plugin_a", "component.load_plugin", {"plugin_name": "alpha"}, ) assert result["success"] is False - assert "已回滚" in result["error"] + assert result["error"] == "插件 alpha 热重载失败" diff --git a/pytests/utils_test/test_session_utils.py b/pytests/utils_test/test_session_utils.py new file mode 100644 index 00000000..c44e2eba --- /dev/null +++ b/pytests/utils_test/test_session_utils.py @@ -0,0 +1,42 @@ +from types import SimpleNamespace + +from src.chat.message_receive.chat_manager import ChatManager +from src.common.utils.utils_session import SessionUtils + + +def test_calculate_session_id_distinguishes_account_and_scope() -> None: + base_session_id = SessionUtils.calculate_session_id("qq", user_id="42") + same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42") + account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123") + route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main") + + assert base_session_id == same_base_session_id + assert account_scoped_session_id != base_session_id + assert route_scoped_session_id != account_scoped_session_id + + +def test_chat_manager_register_message_uses_route_metadata() -> None: + chat_manager = ChatManager() + message = SimpleNamespace( + platform="qq", + session_id="", + message_info=SimpleNamespace( + user_info=SimpleNamespace(user_id="42"), + group_info=SimpleNamespace(group_id="1000"), + additional_config={ + "platform_io_account_id": "123", + "platform_io_scope": "main", + }, + ), + ) + + chat_manager.register_message(message) + + assert message.session_id == SessionUtils.calculate_session_id( + "qq", + user_id="42", + group_id="1000", + account_id="123", + scope="main", + ) + assert chat_manager.last_messages[message.session_id] is message diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 025150fc..1fc4ef53 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -10,6 +10,7 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv from src.common.logger import get_logger from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_session import SessionUtils +from src.platform_io.route_key_factory import RouteKeyFactory # from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.core.announcement_manager import global_announcement_manager @@ -270,11 +271,18 @@ class ChatBot: try: group_info = message.message_info.group_info user_info = message.message_info.user_info + account_id = None + scope = None + additional_config = message.message_info.additional_config + if isinstance(additional_config, dict): + account_id, scope = RouteKeyFactory.extract_components(additional_config) session_id = SessionUtils.calculate_session_id( message.platform, user_id=message.message_info.user_info.user_id, group_id=group_info.group_id if group_info else None, + account_id=account_id, + scope=scope, ) message.session_id = session_id # 正确初始化session_id @@ -317,7 +325,13 @@ class ChatBot: platform = message.platform user_id = user_info.user_id group_id = group_info.group_id if group_info else None - _ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在 + _ = await chat_manager.get_or_create_session( + platform, + user_id, + group_id, + account_id=account_id, + scope=scope, + ) # 确保会话存在 # message.update_chat_stream(chat) diff --git a/src/chat/message_receive/chat_manager.py b/src/chat/message_receive/chat_manager.py index b11d233c..48d89956 100644 --- a/src/chat/message_receive/chat_manager.py +++ b/src/chat/message_receive/chat_manager.py @@ -1,15 +1,16 @@ +import asyncio from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional + from rich.traceback import install from sqlmodel import select -from typing import Optional, TYPE_CHECKING, List, Dict -import asyncio - -from src.common.logger import get_logger from src.common.data_models.chat_session_data_model import MaiChatSession -from src.common.database.database_model import ChatSession from src.common.database.database import get_db_session +from src.common.database.database_model import ChatSession +from src.common.logger import get_logger from src.common.utils.utils_session import SessionUtils +from src.platform_io.route_key_factory import RouteKeyFactory if TYPE_CHECKING: from .message import SessionMessage @@ -82,7 +83,12 @@ class ChatManager: logger.error(f"初始化聊天管理器出现错误: {e}") async def get_or_create_session( - self, platform: str, user_id: str, group_id: Optional[str] = None + self, + platform: str, + user_id: str, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, ) -> BotChatSession: """获取会话,如果不存在则创建一个新会话;一个封装方法。 @@ -90,12 +96,20 @@ class ChatManager: platform: 平台 user_id: 用户ID group_id: 群ID(如果是群聊) + account_id: 平台账号 ID + scope: 路由作用域 Returns: return (BotChatSession) 会话对象 Raises: Exception: 获取或创建会话时发生错误 """ - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) if session := self.get_session_by_session_id(session_id): session.update_active_time() return session @@ -131,7 +145,18 @@ class ChatManager: raise ValueError("消息缺少平台信息") user_id = message.message_info.user_info.user_id group_id = message.message_info.group_info.group_id if message.message_info.group_info else None - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + account_id = None + scope = None + additional_config = message.message_info.additional_config + if isinstance(additional_config, dict): + account_id, scope = RouteKeyFactory.extract_components(additional_config) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) message.session_id = session_id # 确保消息的session_id正确设置 self.last_messages[session_id] = message @@ -188,7 +213,12 @@ class ChatManager: return None def get_session_by_info( - self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None + self, + platform: str, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, ) -> Optional[BotChatSession]: """根据平台、用户ID和群ID获取对应的会话 @@ -196,10 +226,18 @@ class ChatManager: platform: 平台 user_id: 用户ID group_id: 群ID(如果是群聊) + account_id: 平台账号 ID + scope: 路由作用域 Returns: return (Optional[BotChatSession]): 会话对象,如果不存在则返回None """ - session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id) + session_id = SessionUtils.calculate_session_id( + platform, + user_id=user_id, + group_id=group_id, + account_id=account_id, + scope=scope, + ) return self.get_session_by_session_id(session_id) def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]: diff --git a/src/common/utils/utils_session.py b/src/common/utils/utils_session.py index a383f5a2..1b6d8f72 100644 --- a/src/common/utils/utils_session.py +++ b/src/common/utils/utils_session.py @@ -5,13 +5,22 @@ import hashlib class SessionUtils: @staticmethod - def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str: + def calculate_session_id( + platform: str, + *, + user_id: Optional[str] = None, + group_id: Optional[str] = None, + account_id: Optional[str] = None, + scope: Optional[str] = None, + ) -> str: """计算session_id Args: platform: 平台名称 user_id: 用户ID(如果是私聊) group_id: 群ID(如果是群聊) + account_id: 当前平台账号 ID,可选 + scope: 当前路由作用域,可选 Returns: str: 计算得到的会话ID Raises: @@ -19,8 +28,15 @@ class SessionUtils: """ if not user_id and not group_id: raise ValueError("UserID 或 GroupID 必须提供其一") + + route_components = [] + if account_id: + route_components.append(f"account:{account_id}") + if scope: + route_components.append(f"scope:{scope}") + if group_id: - components = [platform, group_id] + components = [platform, *route_components, group_id] else: - components = [platform, user_id, "private"] + components = [platform, *route_components, user_id, "private"] return hashlib.md5("_".join(components).encode()).hexdigest() diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py index a881d399..704ce514 100644 --- a/src/plugin_runtime/__init__.py +++ b/src/plugin_runtime/__init__.py @@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS" ENV_HOST_VERSION = "MAIBOT_HOST_VERSION" """Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验""" + +ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS" +"""Runner 启动时可视为已满足的外部插件依赖列表(JSON 数组)""" + +ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT" +"""Runner 启动时注入的全局配置快照(JSON 对象)""" diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index 67033fdd..2e4c111c 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence from src.common.logger import get_logger @@ -15,8 +15,35 @@ class _RuntimeComponentManagerProtocol(Protocol): @property def supervisors(self) -> List["PluginSupervisor"]: ... + def _normalize_component_type(self, component_type: str) -> str: ... + + def _is_api_component_type(self, component_type: str) -> bool: ... + + def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ... + + def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ... + + def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ... + + def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ... + + def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ... + def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ... + def _resolve_api_target( + self, + caller_plugin_id: str, + api_name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ... + + def _resolve_api_toggle_target( + self, + name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ... + def _resolve_component_toggle_target( self, name: str, component_type: str ) -> tuple[Optional["ComponentEntry"], Optional[str]]: ... @@ -25,6 +52,10 @@ class _RuntimeComponentManagerProtocol(Protocol): def _iter_plugin_dirs(self) -> Iterable[Path]: ... + async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ... + + async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ... + class RuntimeComponentCapabilityMixin: @staticmethod @@ -266,20 +297,22 @@ class RuntimeComponentCapabilityMixin: version=normalized_version, enabled_only=False, ) - if not entries: - return None, None, f"未找到 API: {normalized_name}" - if len(entries) > 1: + if len(entries) == 1: + return supervisor, entries[0], None + if entries: return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version" - return supervisor, entries[0], None + return None, None, f"未找到 API: {normalized_name}" matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] for supervisor in self.supervisors: - for entry in supervisor.api_registry.get_apis( - name=normalized_name, - version=normalized_version, - enabled_only=False, - ): - matches.append((supervisor, entry)) + matches.extend( + (supervisor, entry) + for entry in supervisor.api_registry.get_apis( + name=normalized_name, + version=normalized_version, + enabled_only=False, + ) + ) if len(matches) == 1: return matches[0][0], matches[0][1], None @@ -453,39 +486,14 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} try: - registered_supervisor = self._get_supervisor_for_plugin(plugin_name) - except RuntimeError as exc: - return {"success": False, "error": str(exc)} + loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}") + except Exception as e: + logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} - if registered_supervisor is not None: - try: - reloaded = await registered_supervisor.reload_plugins( - plugin_ids=[plugin_name], - reason=f"load {plugin_name}", - ) - if reloaded: - return {"success": True, "count": 1} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - - for sv in self.supervisors: - for pdir in sv._plugin_dirs: - if (pdir / plugin_name).is_dir(): - try: - reloaded = await sv.reload_plugins( - plugin_ids=[plugin_name], - reason=f"load {plugin_name}", - ) - if reloaded: - return {"success": True, "count": 1} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.load_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - - return {"success": False, "error": f"未找到插件: {plugin_name}"} + if loaded: + return {"success": True, "count": 1} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败"} async def _cap_component_unload_plugin( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] @@ -507,23 +515,14 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": f"检测到重复插件 ID,拒绝热重载: {details}"} try: - sv = self._get_supervisor_for_plugin(plugin_name) - except RuntimeError as exc: - return {"success": False, "error": str(exc)} + reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}") + except Exception as e: + logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") + return {"success": False, "error": str(e)} - if sv is not None: - try: - reloaded = await sv.reload_plugins( - plugin_ids=[plugin_name], - reason=f"reload {plugin_name}", - ) - if reloaded: - return {"success": True} - return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} - except Exception as e: - logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") - return {"success": False, "error": str(e)} - return {"success": False, "error": f"未找到插件: {plugin_name}"} + if reloaded: + return {"success": True} + return {"success": False, "error": f"插件 {plugin_name} 热重载失败"} async def _cap_api_call( self: _RuntimeComponentManagerProtocol, @@ -632,15 +631,16 @@ class RuntimeComponentCapabilityMixin: ) apis: List[Dict[str, Any]] = [] for supervisor in self.supervisors: - for entry in supervisor.api_registry.get_apis( - plugin_id=target_plugin_id or None, - name=api_name, - version=version, - enabled_only=True, - ): - if not self._is_api_visible_to_plugin(entry, plugin_id): - continue - apis.append(self._serialize_api_entry(entry)) + apis.extend( + self._serialize_api_entry(entry) + for entry in supervisor.api_registry.get_apis( + plugin_id=target_plugin_id or None, + name=api_name, + version=version, + enabled_only=True, + ) + if self._is_api_visible_to_plugin(entry, plugin_id) + ) apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"]))) return {"success": True, "apis": apis} diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index afe944e5..693eae51 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -4,17 +4,26 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import asyncio import contextlib +import json import os import sys from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import config_manager, global_config from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager from src.platform_io.drivers import PluginPlatformDriver from src.platform_io.route_key_factory import RouteKeyFactory -from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN +from src.plugin_runtime import ( + ENV_EXTERNAL_PLUGIN_IDS, + ENV_GLOBAL_CONFIG_SNAPSHOT, + ENV_HOST_VERSION, + ENV_IPC_ADDRESS, + ENV_PLUGIN_DIRS, + ENV_SESSION_TOKEN, +) from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, + ConfigReloadScope, ConfigUpdatedPayload, Envelope, HealthPayload, @@ -107,6 +116,7 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {} + self._external_available_plugin_ids: List[str] = [] self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -156,6 +166,21 @@ class PluginRunnerSupervisor: """返回底层 RPC 服务端。""" return self._rpc_server + def set_external_available_plugin_ids(self, plugin_ids: List[str]) -> None: + """设置当前 Runner 启动/重载时可视为已满足的外部依赖列表。""" + + normalized_plugin_ids = { + str(plugin_id or "").strip() + for plugin_id in plugin_ids + if str(plugin_id or "").strip() + } + self._external_available_plugin_ids = sorted(normalized_plugin_ids) + + def get_loaded_plugin_ids(self) -> List[str]: + """返回当前 Supervisor 已注册的插件 ID 列表。""" + + return sorted(self._registered_plugins.keys()) + async def dispatch_event( self, event_type: str, @@ -344,12 +369,18 @@ class PluginRunnerSupervisor: timeout_ms=timeout_ms, ) - async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: + async def reload_plugin( + self, + plugin_id: str, + reason: str = "manual", + external_available_plugins: Optional[List[str]] = None, + ) -> bool: """按插件 ID 触发精确重载。 Args: plugin_id: 目标插件 ID。 reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件 ID 列表。 Returns: bool: 是否重载成功。 @@ -358,7 +389,11 @@ class PluginRunnerSupervisor: response = await self._rpc_server.send_request( "plugin.reload", plugin_id=plugin_id, - payload={"plugin_id": plugin_id, "reason": reason}, + payload={ + "plugin_id": plugin_id, + "reason": reason, + "external_available_plugins": external_available_plugins or self._external_available_plugin_ids, + }, timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), ) except Exception as exc: @@ -374,12 +409,14 @@ class PluginRunnerSupervisor: self, plugin_ids: Optional[List[str]] = None, reason: str = "manual", + external_available_plugins: Optional[List[str]] = None, ) -> bool: """批量重载插件。 Args: plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。 reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件 ID 列表。 Returns: bool: 是否全部重载成功。 @@ -389,7 +426,11 @@ class PluginRunnerSupervisor: success = True for plugin_id in ordered_plugin_ids: - reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason) + reloaded = await self.reload_plugin( + plugin_id=plugin_id, + reason=reason, + external_available_plugins=external_available_plugins, + ) success = success and reloaded return success @@ -399,7 +440,7 @@ class PluginRunnerSupervisor: plugin_id: str, config_data: Optional[Dict[str, Any]] = None, config_version: str = "", - config_scope: str = "self", + config_scope: str | ConfigReloadScope = "self", ) -> bool: """向 Runner 推送插件配置更新。 @@ -412,9 +453,15 @@ class PluginRunnerSupervisor: Returns: bool: 请求是否成功送达并被 Runner 接受。 """ + try: + normalized_scope = ConfigReloadScope(config_scope) + except ValueError: + logger.warning(f"插件 {plugin_id} 配置更新通知失败: 非法的 config_scope={config_scope}") + return False + payload = ConfigUpdatedPayload( plugin_id=plugin_id, - config_scope=config_scope, + config_scope=normalized_scope, config_version=config_version, config_data=config_data or {}, ) @@ -441,11 +488,11 @@ class PluginRunnerSupervisor: List[str]: 已声明订阅该范围的插件 ID 列表。 """ - matched_plugins: List[str] = [] - for plugin_id, registration in self._registered_plugins.items(): - if scope in registration.config_reload_subscriptions: - matched_plugins.append(plugin_id) - return matched_plugins + return [ + plugin_id + for plugin_id, registration in self._registered_plugins.items() + if scope in registration.config_reload_subscriptions + ] async def _wait_for_runner_connection(self, timeout_sec: float) -> None: """等待 Runner 建立 RPC 连接。 @@ -706,10 +753,7 @@ class PluginRunnerSupervisor: ) gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False) - if len(gateways) == 1: - return gateways[0] - - return None + return gateways[0] if len(gateways) == 1 else None async def _register_message_gateway_driver( self, @@ -823,8 +867,7 @@ class PluginRunnerSupervisor: ValueError: 当平台信息缺失时抛出。 """ - platform = str(payload.platform or gateway_entry.platform or "").strip() - if not platform: + if not (platform := str(payload.platform or gateway_entry.platform or "").strip()): raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称") return RouteKey( @@ -1090,7 +1133,11 @@ class PluginRunnerSupervisor: Returns: Dict[str, str]: 传递给 Runner 进程的环境变量映射。 """ + global_config_snapshot = config_manager.get_global_config().model_dump() + global_config_snapshot["model"] = config_manager.get_model_config().model_dump() return { + ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugin_ids, ensure_ascii=False), + ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False), ENV_HOST_VERSION: PROTOCOL_VERSION, ENV_IPC_ADDRESS: self._transport.get_address(), ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs), @@ -1136,8 +1183,7 @@ class PluginRunnerSupervisor: line = await stream.readline() if not line: return - message = line.decode("utf-8", errors="replace").rstrip() - if message: + if message := line.decode("utf-8", errors="replace").rstrip(): logger.warning(f"[runner-stderr] {message}") except asyncio.CancelledError: raise diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index e45b40de..d48260e5 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -8,7 +8,7 @@ """ from pathlib import Path -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple import asyncio import json @@ -102,6 +102,77 @@ class PluginRuntimeManager( candidate = Path("plugins").resolve() return [candidate] if candidate.is_dir() else [] + @staticmethod + def _extract_manifest_dependencies(manifest: Dict[str, Any]) -> List[str]: + """从插件 manifest 中提取规范化后的依赖插件 ID 列表。""" + + dependencies: List[str] = [] + for dependency in manifest.get("dependencies", []): + if isinstance(dependency, str): + normalized_dependency = dependency.strip() + elif isinstance(dependency, dict): + normalized_dependency = str(dependency.get("name", "") or "").strip() + else: + normalized_dependency = "" + + if normalized_dependency: + dependencies.append(normalized_dependency) + return dependencies + + @classmethod + def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]: + """扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。""" + + dependency_map: Dict[str, List[str]] = {} + for plugin_dir in cls._iter_candidate_plugin_paths(plugin_dirs): + manifest_path = plugin_dir / "_manifest.json" + entrypoint_path = plugin_dir / "plugin.py" + if not manifest_path.is_file() or not entrypoint_path.is_file(): + continue + + try: + with manifest_path.open("r", encoding="utf-8") as manifest_file: + manifest = json.load(manifest_file) + except Exception: + continue + + if not isinstance(manifest, dict): + continue + + plugin_id = str(manifest.get("name", plugin_dir.name) or "").strip() or plugin_dir.name + dependency_map[plugin_id] = cls._extract_manifest_dependencies(manifest) + return dependency_map + + @classmethod + def _build_group_start_order( + cls, + builtin_dirs: Sequence[Path], + third_party_dirs: Sequence[Path], + ) -> List[str]: + """根据跨 Supervisor 依赖关系决定 Runner 启动顺序。""" + + builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs) + third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs) + builtin_plugin_ids = set(builtin_dependencies) + third_party_plugin_ids = set(third_party_dependencies) + + builtin_needs_third_party = any( + dependency in third_party_plugin_ids + for dependencies in builtin_dependencies.values() + for dependency in dependencies + ) + third_party_needs_builtin = any( + dependency in builtin_plugin_ids + for dependencies in third_party_dependencies.values() + for dependency in dependencies + ) + + if builtin_needs_third_party and third_party_needs_builtin: + raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner") + if builtin_needs_third_party: + return ["third_party", "builtin"] + return ["builtin", "third_party"] + # ─── 生命周期 ───────────────────────────────────────────── async def start(self) -> None: @@ -161,12 +232,26 @@ class PluginRuntimeManager( platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound) await platform_io_manager.ensure_send_pipeline_ready() - if self._builtin_supervisor: - await self._builtin_supervisor.start() - started_supervisors.append(self._builtin_supervisor) - if self._third_party_supervisor: - await self._third_party_supervisor.start() - started_supervisors.append(self._third_party_supervisor) + supervisor_groups: Dict[str, Optional[PluginSupervisor]] = { + "builtin": self._builtin_supervisor, + "third_party": self._third_party_supervisor, + } + start_order = self._build_group_start_order(builtin_dirs, third_party_dirs) + + for group_name in start_order: + supervisor = supervisor_groups.get(group_name) + if supervisor is None: + continue + + external_plugin_ids = [ + plugin_id + for started_supervisor in started_supervisors + for plugin_id in started_supervisor.get_loaded_plugin_ids() + ] + supervisor.set_external_available_plugin_ids(external_plugin_ids) + await supervisor.start() + started_supervisors.append(supervisor) + await self._start_plugin_file_watcher() config_manager.register_reload_callback(self._config_reload_callback) self._config_reload_callback_registered = True @@ -238,6 +323,171 @@ class PluginRuntimeManager( """获取所有活跃的 Supervisor""" return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None] + def _build_registered_dependency_map(self) -> Dict[str, Set[str]]: + """根据当前已注册插件构建全局依赖图。""" + + dependency_map: Dict[str, Set[str]] = {} + for supervisor in self.supervisors: + for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items(): + dependency_map[plugin_id] = { + str(dependency or "").strip() + for dependency in getattr(registration, "dependencies", []) + if str(dependency or "").strip() + } + return dependency_map + + @staticmethod + def _collect_reverse_dependents( + plugin_ids: Set[str], + dependency_map: Dict[str, Set[str]], + ) -> Set[str]: + """根据依赖图收集反向依赖闭包。""" + + impacted_plugins: Set[str] = set(plugin_ids) + changed = True + + while changed: + changed = False + for registered_plugin_id, dependencies in dependency_map.items(): + if registered_plugin_id in impacted_plugins: + continue + if dependencies & impacted_plugins: + impacted_plugins.add(registered_plugin_id) + changed = True + + return impacted_plugins + + def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]: + """构建当前已注册插件到所属 Supervisor 的映射。""" + + return { + plugin_id: supervisor + for supervisor in self.supervisors + for plugin_id in supervisor.get_loaded_plugin_ids() + } + + def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> List[str]: + """收集某个 Supervisor 可用的外部插件 ID 列表。""" + + external_plugin_ids: Set[str] = set() + for supervisor in self.supervisors: + if supervisor is target_supervisor: + continue + external_plugin_ids.update(supervisor.get_loaded_plugin_ids()) + return sorted(external_plugin_ids) + + def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]: + """根据插件目录推断应负责该插件重载的 Supervisor。""" + + for supervisor in self.supervisors: + for plugin_dir in supervisor._plugin_dirs: + if (Path(plugin_dir) / plugin_id).is_dir(): + return supervisor + return None + + def _warn_skipped_cross_supervisor_reload( + self, + requested_loaded_plugin_ids: Set[str], + dependency_map: Dict[str, Set[str]], + supervisor_by_plugin: Dict[str, "PluginSupervisor"], + ) -> None: + """记录因跨 Supervisor 边界而未参与联动重载的插件。""" + + if not requested_loaded_plugin_ids: + return + + handled_plugin_ids: Set[str] = set() + for supervisor in self.supervisors: + local_requested_plugin_ids = { + plugin_id + for plugin_id in requested_loaded_plugin_ids + if supervisor_by_plugin.get(plugin_id) is supervisor + } + if not local_requested_plugin_ids: + continue + + local_plugin_ids = set(supervisor.get_loaded_plugin_ids()) + local_dependency_map = { + plugin_id: { + dependency + for dependency in dependency_map.get(plugin_id, set()) + if dependency in local_plugin_ids + } + for plugin_id in local_plugin_ids + } + handled_plugin_ids.update( + self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map) + ) + + impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map) + skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids) + if not skipped_plugin_ids: + return + + logger.warning( + f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: " + f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;" + "跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。" + ) + + async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: + """按 Supervisor 分组执行精确重载。 + + 仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警, + 不再自动参与本次热重载。 + """ + + normalized_plugin_ids = [ + normalized_plugin_id + for plugin_id in plugin_ids + if (normalized_plugin_id := str(plugin_id or "").strip()) + ] + if not normalized_plugin_ids: + return True + + dependency_map = self._build_registered_dependency_map() + supervisor_by_plugin = self._build_registered_supervisor_map() + supervisor_roots: Dict["PluginSupervisor", List[str]] = {} + requested_loaded_plugin_ids: Set[str] = set() + missing_plugin_ids: List[str] = [] + + for plugin_id in normalized_plugin_ids: + supervisor = supervisor_by_plugin.get(plugin_id) + if supervisor is not None: + requested_loaded_plugin_ids.add(plugin_id) + else: + supervisor = self._find_supervisor_by_plugin_directory(plugin_id) + + if supervisor is None: + missing_plugin_ids.append(plugin_id) + continue + + if plugin_id not in supervisor_roots.setdefault(supervisor, []): + supervisor_roots[supervisor].append(plugin_id) + + if missing_plugin_ids: + logger.warning(f"以下插件未找到可重载的 Supervisor,已跳过: {', '.join(sorted(missing_plugin_ids))}") + + self._warn_skipped_cross_supervisor_reload( + requested_loaded_plugin_ids=requested_loaded_plugin_ids, + dependency_map=dependency_map, + supervisor_by_plugin=supervisor_by_plugin, + ) + + success = True + for supervisor, root_plugin_ids in supervisor_roots.items(): + if not root_plugin_ids: + continue + + reloaded = await supervisor.reload_plugins( + plugin_ids=root_plugin_ids, + reason=reason, + external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor), + ) + success = success and reloaded + + return success and not missing_plugin_ids + async def notify_plugin_config_updated( self, plugin_id: str, @@ -465,6 +715,31 @@ class PluginRuntimeManager( raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由") return matches[0] if matches else None + async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: + """加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。""" + + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id: + return False + + try: + registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id) + except RuntimeError: + return False + + if registered_supervisor is not None: + return await self.reload_plugins_globally([normalized_plugin_id], reason=reason) + + supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id) + if supervisor is None: + return False + + return await supervisor.reload_plugins( + plugin_ids=[normalized_plugin_id], + reason=reason, + external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor), + ) + @staticmethod def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]: """扫描插件目录,找出被多个目录重复声明的插件 ID。""" @@ -729,7 +1004,7 @@ class PluginRuntimeManager( logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}") return - reload_supervisors: Dict[Any, List[str]] = {} + changed_plugin_ids: List[str] = [] changed_paths = [change.path.resolve() for change in changes] for supervisor in self.supervisors: @@ -738,14 +1013,11 @@ class PluginRuntimeManager( if plugin_id is None: continue if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py": - reload_supervisors.setdefault(supervisor, []) - if plugin_id not in reload_supervisors[supervisor]: - reload_supervisors[supervisor].append(plugin_id) + if plugin_id not in changed_plugin_ids: + changed_plugin_ids.append(plugin_id) - for supervisor, plugin_ids in reload_supervisors.items(): - await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher") - - if reload_supervisors: + if changed_plugin_ids: + await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher") self._refresh_plugin_config_watch_subscriptions() @staticmethod diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 6078e4dc..ce40d855 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -166,6 +166,8 @@ class RegisterPluginPayload(BaseModel): """组件列表""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") """所需能力列表""" + dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表") + """插件级依赖插件 ID 列表""" config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围") """订阅的全局配置热重载范围""" @@ -280,6 +282,8 @@ class ReloadPluginPayload(BaseModel): """目标插件 ID""" reason: str = Field(default="manual", description="重载原因") """重载原因""" + external_available_plugins: List[str] = Field(default_factory=list, description="可视为已满足的外部依赖插件 ID") + """可视为已满足的外部依赖插件 ID""" class ReloadPluginResultPayload(BaseModel): diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index a766eb04..f07eb593 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -95,11 +95,16 @@ class PluginLoader: self._manifest_validator = ManifestValidator(host_version=host_version) self._compat_hook_installed = False - def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]: + def discover_and_load( + self, + plugin_dirs: List[str], + extra_available: Optional[Set[str]] = None, + ) -> List[PluginMeta]: """扫描多个目录并加载所有插件。 Args: plugin_dirs: 插件目录列表。 + extra_available: 额外视为已满足的外部依赖插件 ID 集合。 Returns: List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。 @@ -108,7 +113,7 @@ class PluginLoader: self._record_duplicate_candidates(duplicate_candidates) # 第二阶段:依赖解析(拓扑排序) - load_order, failed_deps = self._resolve_dependencies(candidates) + load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available) self._record_failed_dependencies(failed_deps) # 第三阶段:按依赖顺序加载 diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index b38946d6..c0f5e771 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -15,6 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, ca import asyncio import contextlib import inspect +import json import logging as stdlib_logging import os import signal @@ -23,7 +24,13 @@ import time import tomllib from src.common.logger import get_console_handler, get_logger, initialize_logging -from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN +from src.plugin_runtime import ( + ENV_EXTERNAL_PLUGIN_IDS, + ENV_HOST_VERSION, + ENV_IPC_ADDRESS, + ENV_PLUGIN_DIRS, + ENV_SESSION_TOKEN, +) from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, ComponentDeclaration, @@ -112,6 +119,7 @@ class PluginRunner: host_address: str, session_token: str, plugin_dirs: List[str], + external_available_plugin_ids: Optional[List[str]] = None, ) -> None: """初始化 Runner。 @@ -119,10 +127,16 @@ class PluginRunner: host_address: Host 的 IPC 地址。 session_token: 握手用会话令牌。 plugin_dirs: 当前 Runner 负责扫描的插件目录列表。 + external_available_plugin_ids: 视为已满足的外部依赖插件 ID 列表。 """ self._host_address: str = host_address self._session_token: str = session_token self._plugin_dirs: List[str] = plugin_dirs + self._external_available_plugin_ids: Set[str] = { + str(plugin_id or "").strip() + for plugin_id in (external_available_plugin_ids or []) + if str(plugin_id or "").strip() + } self._rpc_client: RPCClient = RPCClient(host_address, session_token) self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, "")) @@ -150,7 +164,10 @@ class PluginRunner: self._register_handlers() # 3. 加载插件 - plugins = self._loader.discover_and_load(self._plugin_dirs) + plugins = self._loader.discover_and_load( + self._plugin_dirs, + extra_available=self._external_available_plugin_ids, + ) logger.info(f"已加载 {len(plugins)} 个插件") # 4. 注入 PluginContext + 调用 on_load 生命周期钩子 @@ -379,6 +396,7 @@ class PluginRunner: plugin_version=meta.version, components=components, capabilities_required=meta.capabilities_required, + dependencies=meta.dependencies, config_reload_subscriptions=config_reload_subscriptions, ) @@ -485,18 +503,20 @@ class PluginRunner: self._loader.set_loaded_plugin(meta) return True - async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None: + async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None: """卸载单个插件并清理 Host/Runner 两侧状态。 Args: meta: 待卸载的插件元数据。 reason: 卸载原因。 + purge_modules: 是否在卸载完成后清理插件模块缓存。 """ await self._invoke_plugin_on_unload(meta) await self._unregister_plugin(meta.plugin_id, reason) await self._deactivate_plugin(meta) self._loader.remove_loaded_plugin(meta.plugin_id) - self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + if purge_modules: + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]: """收集依赖指定插件的所有已加载插件。 @@ -564,18 +584,52 @@ class PluginRunner: return list(reversed(load_order)) - async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload: + @staticmethod + def _finalize_failed_reload_messages( + failed_plugins: Dict[str, str], + rollback_failures: Dict[str, str], + ) -> Dict[str, str]: + """在重载失败后补充回滚结果说明。""" + + finalized_failures: Dict[str, str] = {} + for failed_plugin_id, failure_reason in failed_plugins.items(): + rollback_failure = rollback_failures.get(failed_plugin_id) + if rollback_failure: + finalized_failures[failed_plugin_id] = ( + f"{failure_reason};且旧版本恢复失败: {rollback_failure}" + ) + else: + finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)" + + for failed_plugin_id, rollback_failure in rollback_failures.items(): + if failed_plugin_id not in finalized_failures: + finalized_failures[failed_plugin_id] = f"旧版本恢复失败: {rollback_failure}" + + return finalized_failures + + async def _reload_plugin_by_id( + self, + plugin_id: str, + reason: str, + external_available_plugins: Optional[Set[str]] = None, + ) -> ReloadPluginResultPayload: """按插件 ID 在 Runner 进程内执行精确重载。 Args: plugin_id: 目标插件 ID。 reason: 重载原因。 + external_available_plugins: 视为已满足的外部依赖插件 ID 集合。 Returns: ReloadPluginResultPayload: 结构化重载结果。 """ candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs) failed_plugins: Dict[str, str] = {} + normalized_external_available = { + str(candidate_plugin_id or "").strip() + for candidate_plugin_id in (external_available_plugins or set()) + if str(candidate_plugin_id or "").strip() + } if plugin_id in duplicate_candidates: conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) @@ -603,29 +657,32 @@ class PluginRunner: unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids) unloaded_plugins: List[str] = [] retained_plugin_ids = loaded_plugin_ids - set(unload_order) + rollback_metas: Dict[str, PluginMeta] = {} for unload_plugin_id in unload_order: meta = self._loader.get_plugin(unload_plugin_id) if meta is None: continue - await self._unload_plugin(meta, reason=reason) + rollback_metas[unload_plugin_id] = meta + await self._unload_plugin(meta, reason=reason, purge_modules=False) + self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir) unloaded_plugins.append(unload_plugin_id) reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {} for target_plugin_id in target_plugin_ids: candidate = candidates.get(target_plugin_id) if candidate is None: - failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态" + failed_plugins[target_plugin_id] = "插件目录已不存在" continue reload_candidates[target_plugin_id] = candidate load_order, dependency_failures = self._loader.resolve_dependencies( reload_candidates, - extra_available=retained_plugin_ids, + extra_available=retained_plugin_ids | normalized_external_available, ) failed_plugins.update(dependency_failures) - available_plugins = set(retained_plugin_ids) + available_plugins = set(retained_plugin_ids) | normalized_external_available reloaded_plugins: List[str] = [] for load_plugin_id in load_order: @@ -656,7 +713,48 @@ class PluginRunner: available_plugins.add(load_plugin_id) reloaded_plugins.append(load_plugin_id) - requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins + if failed_plugins: + rollback_failures: Dict[str, str] = {} + + for reloaded_plugin_id in reversed(reloaded_plugins): + reloaded_meta = self._loader.get_plugin(reloaded_plugin_id) + if reloaded_meta is None: + continue + + try: + await self._unload_plugin( + reloaded_meta, + reason=f"{reason}_rollback_cleanup", + purge_modules=False, + ) + except Exception as exc: + rollback_failures[reloaded_plugin_id] = f"清理失败: {exc}" + finally: + self._loader.purge_plugin_modules(reloaded_plugin_id, reloaded_meta.plugin_dir) + + for rollback_plugin_id in reversed(unload_order): + rollback_meta = rollback_metas.get(rollback_plugin_id) + if rollback_meta is None: + continue + + try: + restored = await self._activate_plugin(rollback_meta) + except Exception as exc: + rollback_failures[rollback_plugin_id] = str(exc) + continue + + if not restored: + rollback_failures[rollback_plugin_id] = "无法重新激活旧版本" + + return ReloadPluginResultPayload( + success=False, + requested_plugin_id=plugin_id, + reloaded_plugins=[], + unloaded_plugins=unloaded_plugins, + failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures), + ) + + requested_plugin_success = plugin_id in reloaded_plugins return ReloadPluginResultPayload( success=requested_plugin_success, @@ -978,7 +1076,11 @@ class PluginRunner: ) async with self._reload_lock: - result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason) + result = await self._reload_plugin_by_id( + payload.plugin_id, + payload.reason, + external_available_plugins=set(payload.external_available_plugins), + ) return envelope.make_response(payload=result.model_dump()) def request_capability(self) -> RPCClient: @@ -1073,6 +1175,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: async def _async_main() -> None: """异步主入口""" host_address = os.environ.get(ENV_IPC_ADDRESS, "") + external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "") session_token = os.environ.get(ENV_SESSION_TOKEN, "") plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "") @@ -1081,11 +1184,24 @@ async def _async_main() -> None: sys.exit(1) plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d] + try: + external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else [] + except json.JSONDecodeError: + logger.warning("解析外部依赖插件列表失败,已回退为空列表") + external_plugin_ids = [] + if not isinstance(external_plugin_ids, list): + logger.warning("外部依赖插件列表格式非法,已回退为空列表") + external_plugin_ids = [] # sys.path 隔离: 只保留标准库、SDK 包、插件目录 _isolate_sys_path(plugin_dirs) - runner = PluginRunner(host_address, session_token, plugin_dirs) + runner = PluginRunner( + host_address, + session_token, + plugin_dirs, + external_available_plugin_ids=[str(plugin_id) for plugin_id in external_plugin_ids], + ) # 注册信号处理 def _mark_runner_shutting_down() -> None: