feat: enhance session ID calculation and plugin management

- Updated `calculate_session_id` method in `SessionUtils` to include optional `account_id` and `scope` parameters for more granular session ID generation.
- Added new environment variables in `plugin_runtime` for external plugin dependencies and global configuration snapshots.
- Introduced methods in `RuntimeComponentManagerProtocol` for loading and reloading plugins globally, accommodating external dependencies.
- Enhanced `PluginRunnerSupervisor` to manage external available plugin IDs during plugin reloads.
- Implemented dependency extraction and management in `PluginRuntimeManager` to handle cross-supervisor dependencies.
- Added tests for session ID calculation and message registration in `ChatManager` to ensure correct behavior with new parameters.
This commit is contained in:
DrSmoothl
2026-03-23 21:48:19 +08:00
parent 7a304ba549
commit 0c508995dd
12 changed files with 765 additions and 170 deletions

View File

@@ -3,6 +3,7 @@
验证协议层、传输层、RPC 通信链路的正确性。
"""
from pathlib import Path
from types import SimpleNamespace
import asyncio
@@ -2362,6 +2363,8 @@ class TestIntegration:
from src.plugin_runtime import integration as integration_module
instances = []
builtin_dir = Path("builtin")
thirdparty_dir = Path("thirdparty")
class FakeCapabilityService:
def register_capability(self, name, impl):
@@ -2369,11 +2372,18 @@ class TestIntegration:
class FakeSupervisor:
def __init__(self, plugin_dirs=None, socket_path=None):
self.plugin_dirs = plugin_dirs or []
self._plugin_dirs = plugin_dirs or []
self.capability_service = FakeCapabilityService()
self.external_plugin_ids = []
self.stopped = False
instances.append(self)
def set_external_available_plugin_ids(self, plugin_ids):
self.external_plugin_ids = list(plugin_ids)
def get_loaded_plugin_ids(self):
return []
async def start(self):
if len(instances) == 2 and self is instances[1]:
raise RuntimeError("boom")
@@ -2382,10 +2392,10 @@ class TestIntegration:
self.stopped = True
monkeypatch.setattr(
integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])
integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir])
)
monkeypatch.setattr(
integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])
integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir])
)
import src.plugin_runtime.host.supervisor as supervisor_module
@@ -2427,8 +2437,11 @@ class TestIntegration:
self.reload_reasons = []
self.config_updates = []
async def reload_plugins(self, plugin_ids=None, reason="manual"):
self.reload_reasons.append((plugin_ids, reason))
def get_loaded_plugin_ids(self):
return sorted(self._registered_plugins.keys())
async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
self.reload_reasons.append((plugin_ids, reason, external_available_plugins or []))
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
self.config_updates.append((plugin_id, config_data, config_version))
@@ -2453,11 +2466,59 @@ class TestIntegration:
await manager._handle_plugin_source_changes(changes)
assert manager._builtin_supervisor.reload_reasons == []
assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")]
assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher", ["alpha"])]
assert manager._builtin_supervisor.config_updates == []
assert manager._third_party_supervisor.config_updates == []
assert refresh_calls == [True]
@pytest.mark.asyncio
async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
class FakeRegistration:
def __init__(self, dependencies):
self.dependencies = dependencies
class FakeSupervisor:
def __init__(self, registrations):
self._registered_plugins = registrations
self.reload_calls = []
def get_loaded_plugin_ids(self):
return sorted(self._registered_plugins.keys())
async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None):
self.reload_calls.append((plugin_ids, reason, sorted(external_available_plugins or [])))
return True
builtin_supervisor = FakeSupervisor({"alpha": FakeRegistration([])})
third_party_supervisor = FakeSupervisor(
{
"beta": FakeRegistration(["alpha"]),
"gamma": FakeRegistration(["beta"]),
}
)
manager = integration_module.PluginRuntimeManager()
manager._builtin_supervisor = builtin_supervisor
manager._third_party_supervisor = third_party_supervisor
warning_messages = []
monkeypatch.setattr(
integration_module.logger,
"warning",
lambda message: warning_messages.append(message),
)
reloaded = await manager.reload_plugins_globally(["alpha"], reason="manual")
assert reloaded is True
assert builtin_supervisor.reload_calls == [(["alpha"], "manual", ["beta", "gamma"])]
assert third_party_supervisor.reload_calls == []
assert len(warning_messages) == 1
assert "beta, gamma" in warning_messages[0]
assert "跨 Supervisor API 调用仍然可用" in warning_messages[0]
@pytest.mark.asyncio
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
@@ -2623,55 +2684,30 @@ class TestIntegration:
async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
class FakeSupervisor:
def __init__(self):
self._registered_plugins = {"alpha": object()}
manager = integration_module.PluginRuntimeManager()
monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False))
async def reload_plugins(self, reason="manual"):
return False
class FakeManager:
def __init__(self):
self.supervisors = [FakeSupervisor()]
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
result = await integration_module.PluginRuntimeManager._cap_component_reload_plugin(
result = await manager._cap_component_reload_plugin(
"plugin_a",
"component.reload_plugin",
{"plugin_name": "alpha"},
)
assert result["success"] is False
assert "已回滚" in result["error"]
assert result["error"] == "插件 alpha 热重载失败"
@pytest.mark.asyncio
async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
plugin_root = tmp_path / "plugins"
plugin_root.mkdir()
(plugin_root / "alpha").mkdir()
manager = integration_module.PluginRuntimeManager()
monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False))
class FakeSupervisor:
def __init__(self):
self._registered_plugins = {}
self._plugin_dirs = [str(plugin_root)]
async def reload_plugins(self, reason="manual"):
return False
class FakeManager:
def __init__(self):
self.supervisors = [FakeSupervisor()]
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
result = await integration_module.PluginRuntimeManager._cap_component_load_plugin(
result = await manager._cap_component_load_plugin(
"plugin_a",
"component.load_plugin",
{"plugin_name": "alpha"},
)
assert result["success"] is False
assert "已回滚" in result["error"]
assert result["error"] == "插件 alpha 热重载失败"

View File

@@ -0,0 +1,42 @@
from types import SimpleNamespace
from src.chat.message_receive.chat_manager import ChatManager
from src.common.utils.utils_session import SessionUtils
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
assert base_session_id == same_base_session_id
assert account_scoped_session_id != base_session_id
assert route_scoped_session_id != account_scoped_session_id
def test_chat_manager_register_message_uses_route_metadata() -> None:
chat_manager = ChatManager()
message = SimpleNamespace(
platform="qq",
session_id="",
message_info=SimpleNamespace(
user_info=SimpleNamespace(user_id="42"),
group_info=SimpleNamespace(group_id="1000"),
additional_config={
"platform_io_account_id": "123",
"platform_io_scope": "main",
},
),
)
chat_manager.register_message(message)
assert message.session_id == SessionUtils.calculate_session_id(
"qq",
user_id="42",
group_id="1000",
account_id="123",
scope="main",
)
assert chat_manager.last_messages[message.session_id] is message

View File

@@ -10,6 +10,7 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
from src.platform_io.route_key_factory import RouteKeyFactory
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.core.announcement_manager import global_announcement_manager
@@ -270,11 +271,18 @@ class ChatBot:
try:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
account_id = None
scope = None
additional_config = message.message_info.additional_config
if isinstance(additional_config, dict):
account_id, scope = RouteKeyFactory.extract_components(additional_config)
session_id = SessionUtils.calculate_session_id(
message.platform,
user_id=message.message_info.user_info.user_id,
group_id=group_info.group_id if group_info else None,
account_id=account_id,
scope=scope,
)
message.session_id = session_id # 正确初始化session_id
@@ -317,7 +325,13 @@ class ChatBot:
platform = message.platform
user_id = user_info.user_id
group_id = group_info.group_id if group_info else None
_ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
_ = await chat_manager.get_or_create_session(
platform,
user_id,
group_id,
account_id=account_id,
scope=scope,
) # 确保会话存在
# message.update_chat_stream(chat)

View File

@@ -1,15 +1,16 @@
import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Dict, List, Optional
from rich.traceback import install
from sqlmodel import select
from typing import Optional, TYPE_CHECKING, List, Dict
import asyncio
from src.common.logger import get_logger
from src.common.data_models.chat_session_data_model import MaiChatSession
from src.common.database.database_model import ChatSession
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession
from src.common.logger import get_logger
from src.common.utils.utils_session import SessionUtils
from src.platform_io.route_key_factory import RouteKeyFactory
if TYPE_CHECKING:
from .message import SessionMessage
@@ -82,7 +83,12 @@ class ChatManager:
logger.error(f"初始化聊天管理器出现错误: {e}")
async def get_or_create_session(
self, platform: str, user_id: str, group_id: Optional[str] = None
self,
platform: str,
user_id: str,
group_id: Optional[str] = None,
account_id: Optional[str] = None,
scope: Optional[str] = None,
) -> BotChatSession:
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
@@ -90,12 +96,20 @@ class ChatManager:
platform: 平台
user_id: 用户ID
group_id: 群ID如果是群聊
account_id: 平台账号 ID
scope: 路由作用域
Returns:
return (BotChatSession) 会话对象
Raises:
Exception: 获取或创建会话时发生错误
"""
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
session_id = SessionUtils.calculate_session_id(
platform,
user_id=user_id,
group_id=group_id,
account_id=account_id,
scope=scope,
)
if session := self.get_session_by_session_id(session_id):
session.update_active_time()
return session
@@ -131,7 +145,18 @@ class ChatManager:
raise ValueError("消息缺少平台信息")
user_id = message.message_info.user_info.user_id
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
account_id = None
scope = None
additional_config = message.message_info.additional_config
if isinstance(additional_config, dict):
account_id, scope = RouteKeyFactory.extract_components(additional_config)
session_id = SessionUtils.calculate_session_id(
platform,
user_id=user_id,
group_id=group_id,
account_id=account_id,
scope=scope,
)
message.session_id = session_id # 确保消息的session_id正确设置
self.last_messages[session_id] = message
@@ -188,7 +213,12 @@ class ChatManager:
return None
def get_session_by_info(
self, platform: str, user_id: Optional[str] = None, group_id: Optional[str] = None
self,
platform: str,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
account_id: Optional[str] = None,
scope: Optional[str] = None,
) -> Optional[BotChatSession]:
"""根据平台、用户ID和群ID获取对应的会话
@@ -196,10 +226,18 @@ class ChatManager:
platform: 平台
user_id: 用户ID
group_id: 群ID如果是群聊
account_id: 平台账号 ID
scope: 路由作用域
Returns:
return (Optional[BotChatSession]): 会话对象如果不存在则返回None
"""
session_id = SessionUtils.calculate_session_id(platform, user_id=user_id, group_id=group_id)
session_id = SessionUtils.calculate_session_id(
platform,
user_id=user_id,
group_id=group_id,
account_id=account_id,
scope=scope,
)
return self.get_session_by_session_id(session_id)
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:

View File

@@ -5,13 +5,22 @@ import hashlib
class SessionUtils:
@staticmethod
def calculate_session_id(platform: str, *, user_id: Optional[str] = None, group_id: Optional[str] = None) -> str:
def calculate_session_id(
platform: str,
*,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
account_id: Optional[str] = None,
scope: Optional[str] = None,
) -> str:
"""计算session_id
Args:
platform: 平台名称
user_id: 用户ID如果是私聊
group_id: 群ID如果是群聊
account_id: 当前平台账号 ID可选
scope: 当前路由作用域,可选
Returns:
str: 计算得到的会话ID
Raises:
@@ -19,8 +28,15 @@ class SessionUtils:
"""
if not user_id and not group_id:
raise ValueError("UserID 或 GroupID 必须提供其一")
route_components = []
if account_id:
route_components.append(f"account:{account_id}")
if scope:
route_components.append(f"scope:{scope}")
if group_id:
components = [platform, group_id]
components = [platform, *route_components, group_id]
else:
components = [platform, user_id, "private"]
components = [platform, *route_components, user_id, "private"]
return hashlib.md5("_".join(components).encode()).hexdigest()

View File

@@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
"""Runner 启动时可视为已满足的外部插件依赖列表JSON 数组)"""
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
"""Runner 启动时注入的全局配置快照JSON 对象)"""

View File

@@ -1,5 +1,5 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence
from src.common.logger import get_logger
@@ -15,8 +15,35 @@ class _RuntimeComponentManagerProtocol(Protocol):
@property
def supervisors(self) -> List["PluginSupervisor"]: ...
def _normalize_component_type(self, component_type: str) -> str: ...
def _is_api_component_type(self, component_type: str) -> bool: ...
def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ...
def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ...
def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
def _resolve_api_target(
self,
caller_plugin_id: str,
api_name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
def _resolve_api_toggle_target(
self,
name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
def _resolve_component_toggle_target(
self, name: str, component_type: str
) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
@@ -25,6 +52,10 @@ class _RuntimeComponentManagerProtocol(Protocol):
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ...
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ...
class RuntimeComponentCapabilityMixin:
@staticmethod
@@ -266,20 +297,22 @@ class RuntimeComponentCapabilityMixin:
version=normalized_version,
enabled_only=False,
)
if not entries:
return None, None, f"未找到 API: {normalized_name}"
if len(entries) > 1:
if len(entries) == 1:
return supervisor, entries[0], None
if entries:
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
return supervisor, entries[0], None
return None, None, f"未找到 API: {normalized_name}"
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(
name=normalized_name,
version=normalized_version,
enabled_only=False,
):
matches.append((supervisor, entry))
matches.extend(
(supervisor, entry)
for entry in supervisor.api_registry.get_apis(
name=normalized_name,
version=normalized_version,
enabled_only=False,
)
)
if len(matches) == 1:
return matches[0][0], matches[0][1], None
@@ -453,39 +486,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}
try:
registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}")
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
if registered_supervisor is not None:
try:
reloaded = await registered_supervisor.reload_plugins(
plugin_ids=[plugin_name],
reason=f"load {plugin_name}",
)
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
for sv in self.supervisors:
for pdir in sv._plugin_dirs:
if (pdir / plugin_name).is_dir():
try:
reloaded = await sv.reload_plugins(
plugin_ids=[plugin_name],
reason=f"load {plugin_name}",
)
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
if loaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_component_unload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
@@ -507,23 +515,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}
try:
sv = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}")
except Exception as e:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
if sv is not None:
try:
reloaded = await sv.reload_plugins(
plugin_ids=[plugin_name],
reason=f"reload {plugin_name}",
)
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_api_call(
self: _RuntimeComponentManagerProtocol,
@@ -632,15 +631,16 @@ class RuntimeComponentCapabilityMixin:
)
apis: List[Dict[str, Any]] = []
for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(
plugin_id=target_plugin_id or None,
name=api_name,
version=version,
enabled_only=True,
):
if not self._is_api_visible_to_plugin(entry, plugin_id):
continue
apis.append(self._serialize_api_entry(entry))
apis.extend(
self._serialize_api_entry(entry)
for entry in supervisor.api_registry.get_apis(
plugin_id=target_plugin_id or None,
name=api_name,
version=version,
enabled_only=True,
)
if self._is_api_visible_to_plugin(entry, plugin_id)
)
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
return {"success": True, "apis": apis}

View File

@@ -4,17 +4,26 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import asyncio
import contextlib
import json
import os
import sys
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import config_manager, global_config
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime import (
ENV_EXTERNAL_PLUGIN_IDS,
ENV_GLOBAL_CONFIG_SNAPSHOT,
ENV_HOST_VERSION,
ENV_IPC_ADDRESS,
ENV_PLUGIN_DIRS,
ENV_SESSION_TOKEN,
)
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ConfigReloadScope,
ConfigUpdatedPayload,
Envelope,
HealthPayload,
@@ -107,6 +116,7 @@ class PluginRunnerSupervisor:
self._runner_process: Optional[asyncio.subprocess.Process] = None
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
self._message_gateway_states: Dict[str, Dict[str, _MessageGatewayRuntimeState]] = {}
self._external_available_plugin_ids: List[str] = []
self._runner_ready_events: asyncio.Event = asyncio.Event()
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
self._health_task: Optional[asyncio.Task[None]] = None
@@ -156,6 +166,21 @@ class PluginRunnerSupervisor:
"""返回底层 RPC 服务端。"""
return self._rpc_server
def set_external_available_plugin_ids(self, plugin_ids: List[str]) -> None:
"""设置当前 Runner 启动/重载时可视为已满足的外部依赖列表。"""
normalized_plugin_ids = {
str(plugin_id or "").strip()
for plugin_id in plugin_ids
if str(plugin_id or "").strip()
}
self._external_available_plugin_ids = sorted(normalized_plugin_ids)
def get_loaded_plugin_ids(self) -> List[str]:
"""返回当前 Supervisor 已注册的插件 ID 列表。"""
return sorted(self._registered_plugins.keys())
async def dispatch_event(
self,
event_type: str,
@@ -344,12 +369,18 @@ class PluginRunnerSupervisor:
timeout_ms=timeout_ms,
)
async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
async def reload_plugin(
self,
plugin_id: str,
reason: str = "manual",
external_available_plugins: Optional[List[str]] = None,
) -> bool:
"""按插件 ID 触发精确重载。
Args:
plugin_id: 目标插件 ID。
reason: 重载原因。
external_available_plugins: 视为已满足的外部依赖插件 ID 列表。
Returns:
bool: 是否重载成功。
@@ -358,7 +389,11 @@ class PluginRunnerSupervisor:
response = await self._rpc_server.send_request(
"plugin.reload",
plugin_id=plugin_id,
payload={"plugin_id": plugin_id, "reason": reason},
payload={
"plugin_id": plugin_id,
"reason": reason,
"external_available_plugins": external_available_plugins or self._external_available_plugin_ids,
},
timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
)
except Exception as exc:
@@ -374,12 +409,14 @@ class PluginRunnerSupervisor:
self,
plugin_ids: Optional[List[str]] = None,
reason: str = "manual",
external_available_plugins: Optional[List[str]] = None,
) -> bool:
"""批量重载插件。
Args:
plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。
reason: 重载原因。
external_available_plugins: 视为已满足的外部依赖插件 ID 列表。
Returns:
bool: 是否全部重载成功。
@@ -389,7 +426,11 @@ class PluginRunnerSupervisor:
success = True
for plugin_id in ordered_plugin_ids:
reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason)
reloaded = await self.reload_plugin(
plugin_id=plugin_id,
reason=reason,
external_available_plugins=external_available_plugins,
)
success = success and reloaded
return success
@@ -399,7 +440,7 @@ class PluginRunnerSupervisor:
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
config_version: str = "",
config_scope: str = "self",
config_scope: str | ConfigReloadScope = "self",
) -> bool:
"""向 Runner 推送插件配置更新。
@@ -412,9 +453,15 @@ class PluginRunnerSupervisor:
Returns:
bool: 请求是否成功送达并被 Runner 接受。
"""
try:
normalized_scope = ConfigReloadScope(config_scope)
except ValueError:
logger.warning(f"插件 {plugin_id} 配置更新通知失败: 非法的 config_scope={config_scope}")
return False
payload = ConfigUpdatedPayload(
plugin_id=plugin_id,
config_scope=config_scope,
config_scope=normalized_scope,
config_version=config_version,
config_data=config_data or {},
)
@@ -441,11 +488,11 @@ class PluginRunnerSupervisor:
List[str]: 已声明订阅该范围的插件 ID 列表。
"""
matched_plugins: List[str] = []
for plugin_id, registration in self._registered_plugins.items():
if scope in registration.config_reload_subscriptions:
matched_plugins.append(plugin_id)
return matched_plugins
return [
plugin_id
for plugin_id, registration in self._registered_plugins.items()
if scope in registration.config_reload_subscriptions
]
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
"""等待 Runner 建立 RPC 连接。
@@ -706,10 +753,7 @@ class PluginRunnerSupervisor:
)
gateways = self._component_registry.get_message_gateways(plugin_id=plugin_id, enabled_only=False)
if len(gateways) == 1:
return gateways[0]
return None
return gateways[0] if len(gateways) == 1 else None
async def _register_message_gateway_driver(
self,
@@ -823,8 +867,7 @@ class PluginRunnerSupervisor:
ValueError: 当平台信息缺失时抛出。
"""
platform = str(payload.platform or gateway_entry.platform or "").strip()
if not platform:
if not (platform := str(payload.platform or gateway_entry.platform or "").strip()):
raise ValueError(f"消息网关 {gateway_entry.full_name} 未提供有效的平台名称")
return RouteKey(
@@ -1090,7 +1133,11 @@ class PluginRunnerSupervisor:
Returns:
Dict[str, str]: 传递给 Runner 进程的环境变量映射。
"""
global_config_snapshot = config_manager.get_global_config().model_dump()
global_config_snapshot["model"] = config_manager.get_model_config().model_dump()
return {
ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugin_ids, ensure_ascii=False),
ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
ENV_HOST_VERSION: PROTOCOL_VERSION,
ENV_IPC_ADDRESS: self._transport.get_address(),
ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),
@@ -1136,8 +1183,7 @@ class PluginRunnerSupervisor:
line = await stream.readline()
if not line:
return
message = line.decode("utf-8", errors="replace").rstrip()
if message:
if message := line.decode("utf-8", errors="replace").rstrip():
logger.warning(f"[runner-stderr] {message}")
except asyncio.CancelledError:
raise

View File

@@ -8,7 +8,7 @@
"""
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import asyncio
import json
@@ -102,6 +102,77 @@ class PluginRuntimeManager(
candidate = Path("plugins").resolve()
return [candidate] if candidate.is_dir() else []
@staticmethod
def _extract_manifest_dependencies(manifest: Dict[str, Any]) -> List[str]:
"""从插件 manifest 中提取规范化后的依赖插件 ID 列表。"""
dependencies: List[str] = []
for dependency in manifest.get("dependencies", []):
if isinstance(dependency, str):
normalized_dependency = dependency.strip()
elif isinstance(dependency, dict):
normalized_dependency = str(dependency.get("name", "") or "").strip()
else:
normalized_dependency = ""
if normalized_dependency:
dependencies.append(normalized_dependency)
return dependencies
@classmethod
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
dependency_map: Dict[str, List[str]] = {}
for plugin_dir in cls._iter_candidate_plugin_paths(plugin_dirs):
manifest_path = plugin_dir / "_manifest.json"
entrypoint_path = plugin_dir / "plugin.py"
if not manifest_path.is_file() or not entrypoint_path.is_file():
continue
try:
with manifest_path.open("r", encoding="utf-8") as manifest_file:
manifest = json.load(manifest_file)
except Exception:
continue
if not isinstance(manifest, dict):
continue
plugin_id = str(manifest.get("name", plugin_dir.name) or "").strip() or plugin_dir.name
dependency_map[plugin_id] = cls._extract_manifest_dependencies(manifest)
return dependency_map
@classmethod
def _build_group_start_order(
cls,
builtin_dirs: Sequence[Path],
third_party_dirs: Sequence[Path],
) -> List[str]:
"""根据跨 Supervisor 依赖关系决定 Runner 启动顺序。"""
builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs)
third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs)
builtin_plugin_ids = set(builtin_dependencies)
third_party_plugin_ids = set(third_party_dependencies)
builtin_needs_third_party = any(
dependency in third_party_plugin_ids
for dependencies in builtin_dependencies.values()
for dependency in dependencies
)
third_party_needs_builtin = any(
dependency in builtin_plugin_ids
for dependencies in third_party_dependencies.values()
for dependency in dependencies
)
if builtin_needs_third_party and third_party_needs_builtin:
raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner")
if builtin_needs_third_party:
return ["third_party", "builtin"]
return ["builtin", "third_party"]
# ─── 生命周期 ─────────────────────────────────────────────
async def start(self) -> None:
@@ -161,12 +232,26 @@ class PluginRuntimeManager(
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
await platform_io_manager.ensure_send_pipeline_ready()
if self._builtin_supervisor:
await self._builtin_supervisor.start()
started_supervisors.append(self._builtin_supervisor)
if self._third_party_supervisor:
await self._third_party_supervisor.start()
started_supervisors.append(self._third_party_supervisor)
supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
"builtin": self._builtin_supervisor,
"third_party": self._third_party_supervisor,
}
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
for group_name in start_order:
supervisor = supervisor_groups.get(group_name)
if supervisor is None:
continue
external_plugin_ids = [
plugin_id
for started_supervisor in started_supervisors
for plugin_id in started_supervisor.get_loaded_plugin_ids()
]
supervisor.set_external_available_plugin_ids(external_plugin_ids)
await supervisor.start()
started_supervisors.append(supervisor)
await self._start_plugin_file_watcher()
config_manager.register_reload_callback(self._config_reload_callback)
self._config_reload_callback_registered = True
@@ -238,6 +323,171 @@ class PluginRuntimeManager(
"""获取所有活跃的 Supervisor"""
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
"""根据当前已注册插件构建全局依赖图。"""
dependency_map: Dict[str, Set[str]] = {}
for supervisor in self.supervisors:
for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items():
dependency_map[plugin_id] = {
str(dependency or "").strip()
for dependency in getattr(registration, "dependencies", [])
if str(dependency or "").strip()
}
return dependency_map
@staticmethod
def _collect_reverse_dependents(
plugin_ids: Set[str],
dependency_map: Dict[str, Set[str]],
) -> Set[str]:
"""根据依赖图收集反向依赖闭包。"""
impacted_plugins: Set[str] = set(plugin_ids)
changed = True
while changed:
changed = False
for registered_plugin_id, dependencies in dependency_map.items():
if registered_plugin_id in impacted_plugins:
continue
if dependencies & impacted_plugins:
impacted_plugins.add(registered_plugin_id)
changed = True
return impacted_plugins
def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]:
"""构建当前已注册插件到所属 Supervisor 的映射。"""
return {
plugin_id: supervisor
for supervisor in self.supervisors
for plugin_id in supervisor.get_loaded_plugin_ids()
}
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> List[str]:
"""收集某个 Supervisor 可用的外部插件 ID 列表。"""
external_plugin_ids: Set[str] = set()
for supervisor in self.supervisors:
if supervisor is target_supervisor:
continue
external_plugin_ids.update(supervisor.get_loaded_plugin_ids())
return sorted(external_plugin_ids)
def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
"""根据插件目录推断应负责该插件重载的 Supervisor。"""
for supervisor in self.supervisors:
for plugin_dir in supervisor._plugin_dirs:
if (Path(plugin_dir) / plugin_id).is_dir():
return supervisor
return None
def _warn_skipped_cross_supervisor_reload(
self,
requested_loaded_plugin_ids: Set[str],
dependency_map: Dict[str, Set[str]],
supervisor_by_plugin: Dict[str, "PluginSupervisor"],
) -> None:
"""记录因跨 Supervisor 边界而未参与联动重载的插件。"""
if not requested_loaded_plugin_ids:
return
handled_plugin_ids: Set[str] = set()
for supervisor in self.supervisors:
local_requested_plugin_ids = {
plugin_id
for plugin_id in requested_loaded_plugin_ids
if supervisor_by_plugin.get(plugin_id) is supervisor
}
if not local_requested_plugin_ids:
continue
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
local_dependency_map = {
plugin_id: {
dependency
for dependency in dependency_map.get(plugin_id, set())
if dependency in local_plugin_ids
}
for plugin_id in local_plugin_ids
}
handled_plugin_ids.update(
self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map)
)
impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map)
skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids)
if not skipped_plugin_ids:
return
logger.warning(
f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: "
f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;"
"跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。"
)
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool:
"""按 Supervisor 分组执行精确重载。
仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警,
不再自动参与本次热重载。
"""
normalized_plugin_ids = [
normalized_plugin_id
for plugin_id in plugin_ids
if (normalized_plugin_id := str(plugin_id or "").strip())
]
if not normalized_plugin_ids:
return True
dependency_map = self._build_registered_dependency_map()
supervisor_by_plugin = self._build_registered_supervisor_map()
supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
requested_loaded_plugin_ids: Set[str] = set()
missing_plugin_ids: List[str] = []
for plugin_id in normalized_plugin_ids:
supervisor = supervisor_by_plugin.get(plugin_id)
if supervisor is not None:
requested_loaded_plugin_ids.add(plugin_id)
else:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
missing_plugin_ids.append(plugin_id)
continue
if plugin_id not in supervisor_roots.setdefault(supervisor, []):
supervisor_roots[supervisor].append(plugin_id)
if missing_plugin_ids:
logger.warning(f"以下插件未找到可重载的 Supervisor已跳过: {', '.join(sorted(missing_plugin_ids))}")
self._warn_skipped_cross_supervisor_reload(
requested_loaded_plugin_ids=requested_loaded_plugin_ids,
dependency_map=dependency_map,
supervisor_by_plugin=supervisor_by_plugin,
)
success = True
for supervisor, root_plugin_ids in supervisor_roots.items():
if not root_plugin_ids:
continue
reloaded = await supervisor.reload_plugins(
plugin_ids=root_plugin_ids,
reason=reason,
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
success = success and reloaded
return success and not missing_plugin_ids
async def notify_plugin_config_updated(
self,
plugin_id: str,
@@ -465,6 +715,31 @@ class PluginRuntimeManager(
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
return matches[0] if matches else None
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
return False
try:
registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
except RuntimeError:
return False
if registered_supervisor is not None:
return await self.reload_plugins_globally([normalized_plugin_id], reason=reason)
supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id)
if supervisor is None:
return False
return await supervisor.reload_plugins(
plugin_ids=[normalized_plugin_id],
reason=reason,
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
@staticmethod
def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
@@ -729,7 +1004,7 @@ class PluginRuntimeManager(
logger.error(f"检测到重复插件 ID跳过本次插件热重载: {details}")
return
reload_supervisors: Dict[Any, List[str]] = {}
changed_plugin_ids: List[str] = []
changed_paths = [change.path.resolve() for change in changes]
for supervisor in self.supervisors:
@@ -738,14 +1013,11 @@ class PluginRuntimeManager(
if plugin_id is None:
continue
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
reload_supervisors.setdefault(supervisor, [])
if plugin_id not in reload_supervisors[supervisor]:
reload_supervisors[supervisor].append(plugin_id)
if plugin_id not in changed_plugin_ids:
changed_plugin_ids.append(plugin_id)
for supervisor, plugin_ids in reload_supervisors.items():
await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher")
if reload_supervisors:
if changed_plugin_ids:
await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
self._refresh_plugin_config_watch_subscriptions()
@staticmethod

View File

@@ -166,6 +166,8 @@ class RegisterPluginPayload(BaseModel):
"""组件列表"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
"""插件级依赖插件 ID 列表"""
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
"""订阅的全局配置热重载范围"""
@@ -280,6 +282,8 @@ class ReloadPluginPayload(BaseModel):
"""目标插件 ID"""
reason: str = Field(default="manual", description="重载原因")
"""重载原因"""
external_available_plugins: List[str] = Field(default_factory=list, description="可视为已满足的外部依赖插件 ID")
"""可视为已满足的外部依赖插件 ID"""
class ReloadPluginResultPayload(BaseModel):

View File

@@ -95,11 +95,16 @@ class PluginLoader:
self._manifest_validator = ManifestValidator(host_version=host_version)
self._compat_hook_installed = False
def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
def discover_and_load(
self,
plugin_dirs: List[str],
extra_available: Optional[Set[str]] = None,
) -> List[PluginMeta]:
"""扫描多个目录并加载所有插件。
Args:
plugin_dirs: 插件目录列表。
extra_available: 额外视为已满足的外部依赖插件 ID 集合。
Returns:
List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。
@@ -108,7 +113,7 @@ class PluginLoader:
self._record_duplicate_candidates(duplicate_candidates)
# 第二阶段:依赖解析(拓扑排序)
load_order, failed_deps = self._resolve_dependencies(candidates)
load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available)
self._record_failed_dependencies(failed_deps)
# 第三阶段:按依赖顺序加载

View File

@@ -15,6 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, ca
import asyncio
import contextlib
import inspect
import json
import logging as stdlib_logging
import os
import signal
@@ -23,7 +24,13 @@ import time
import tomllib
from src.common.logger import get_console_handler, get_logger, initialize_logging
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
from src.plugin_runtime import (
ENV_EXTERNAL_PLUGIN_IDS,
ENV_HOST_VERSION,
ENV_IPC_ADDRESS,
ENV_PLUGIN_DIRS,
ENV_SESSION_TOKEN,
)
from src.plugin_runtime.protocol.envelope import (
BootstrapPluginPayload,
ComponentDeclaration,
@@ -112,6 +119,7 @@ class PluginRunner:
host_address: str,
session_token: str,
plugin_dirs: List[str],
external_available_plugin_ids: Optional[List[str]] = None,
) -> None:
"""初始化 Runner。
@@ -119,10 +127,16 @@ class PluginRunner:
host_address: Host 的 IPC 地址。
session_token: 握手用会话令牌。
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
external_available_plugin_ids: 视为已满足的外部依赖插件 ID 列表。
"""
self._host_address: str = host_address
self._session_token: str = session_token
self._plugin_dirs: List[str] = plugin_dirs
self._external_available_plugin_ids: Set[str] = {
str(plugin_id or "").strip()
for plugin_id in (external_available_plugin_ids or [])
if str(plugin_id or "").strip()
}
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
@@ -150,7 +164,10 @@ class PluginRunner:
self._register_handlers()
# 3. 加载插件
plugins = self._loader.discover_and_load(self._plugin_dirs)
plugins = self._loader.discover_and_load(
self._plugin_dirs,
extra_available=self._external_available_plugin_ids,
)
logger.info(f"已加载 {len(plugins)} 个插件")
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
@@ -379,6 +396,7 @@ class PluginRunner:
plugin_version=meta.version,
components=components,
capabilities_required=meta.capabilities_required,
dependencies=meta.dependencies,
config_reload_subscriptions=config_reload_subscriptions,
)
@@ -485,18 +503,20 @@ class PluginRunner:
self._loader.set_loaded_plugin(meta)
return True
async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None:
async def _unload_plugin(self, meta: PluginMeta, reason: str, *, purge_modules: bool = True) -> None:
"""卸载单个插件并清理 Host/Runner 两侧状态。
Args:
meta: 待卸载的插件元数据。
reason: 卸载原因。
purge_modules: 是否在卸载完成后清理插件模块缓存。
"""
await self._invoke_plugin_on_unload(meta)
await self._unregister_plugin(meta.plugin_id, reason)
await self._deactivate_plugin(meta)
self._loader.remove_loaded_plugin(meta.plugin_id)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
if purge_modules:
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]:
"""收集依赖指定插件的所有已加载插件。
@@ -564,18 +584,52 @@ class PluginRunner:
return list(reversed(load_order))
async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload:
@staticmethod
def _finalize_failed_reload_messages(
failed_plugins: Dict[str, str],
rollback_failures: Dict[str, str],
) -> Dict[str, str]:
"""在重载失败后补充回滚结果说明。"""
finalized_failures: Dict[str, str] = {}
for failed_plugin_id, failure_reason in failed_plugins.items():
rollback_failure = rollback_failures.get(failed_plugin_id)
if rollback_failure:
finalized_failures[failed_plugin_id] = (
f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
)
else:
finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)"
for failed_plugin_id, rollback_failure in rollback_failures.items():
if failed_plugin_id not in finalized_failures:
finalized_failures[failed_plugin_id] = f"旧版本恢复失败: {rollback_failure}"
return finalized_failures
async def _reload_plugin_by_id(
self,
plugin_id: str,
reason: str,
external_available_plugins: Optional[Set[str]] = None,
) -> ReloadPluginResultPayload:
"""按插件 ID 在 Runner 进程内执行精确重载。
Args:
plugin_id: 目标插件 ID。
reason: 重载原因。
external_available_plugins: 视为已满足的外部依赖插件 ID 集合。
Returns:
ReloadPluginResultPayload: 结构化重载结果。
"""
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
failed_plugins: Dict[str, str] = {}
normalized_external_available = {
str(candidate_plugin_id or "").strip()
for candidate_plugin_id in (external_available_plugins or set())
if str(candidate_plugin_id or "").strip()
}
if plugin_id in duplicate_candidates:
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
@@ -603,29 +657,32 @@ class PluginRunner:
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
unloaded_plugins: List[str] = []
retained_plugin_ids = loaded_plugin_ids - set(unload_order)
rollback_metas: Dict[str, PluginMeta] = {}
for unload_plugin_id in unload_order:
meta = self._loader.get_plugin(unload_plugin_id)
if meta is None:
continue
await self._unload_plugin(meta, reason=reason)
rollback_metas[unload_plugin_id] = meta
await self._unload_plugin(meta, reason=reason, purge_modules=False)
self._loader.purge_plugin_modules(unload_plugin_id, meta.plugin_dir)
unloaded_plugins.append(unload_plugin_id)
reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {}
for target_plugin_id in target_plugin_ids:
candidate = candidates.get(target_plugin_id)
if candidate is None:
failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态"
failed_plugins[target_plugin_id] = "插件目录已不存在"
continue
reload_candidates[target_plugin_id] = candidate
load_order, dependency_failures = self._loader.resolve_dependencies(
reload_candidates,
extra_available=retained_plugin_ids,
extra_available=retained_plugin_ids | normalized_external_available,
)
failed_plugins.update(dependency_failures)
available_plugins = set(retained_plugin_ids)
available_plugins = set(retained_plugin_ids) | normalized_external_available
reloaded_plugins: List[str] = []
for load_plugin_id in load_order:
@@ -656,7 +713,48 @@ class PluginRunner:
available_plugins.add(load_plugin_id)
reloaded_plugins.append(load_plugin_id)
requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins
if failed_plugins:
rollback_failures: Dict[str, str] = {}
for reloaded_plugin_id in reversed(reloaded_plugins):
reloaded_meta = self._loader.get_plugin(reloaded_plugin_id)
if reloaded_meta is None:
continue
try:
await self._unload_plugin(
reloaded_meta,
reason=f"{reason}_rollback_cleanup",
purge_modules=False,
)
except Exception as exc:
rollback_failures[reloaded_plugin_id] = f"清理失败: {exc}"
finally:
self._loader.purge_plugin_modules(reloaded_plugin_id, reloaded_meta.plugin_dir)
for rollback_plugin_id in reversed(unload_order):
rollback_meta = rollback_metas.get(rollback_plugin_id)
if rollback_meta is None:
continue
try:
restored = await self._activate_plugin(rollback_meta)
except Exception as exc:
rollback_failures[rollback_plugin_id] = str(exc)
continue
if not restored:
rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
return ReloadPluginResultPayload(
success=False,
requested_plugin_id=plugin_id,
reloaded_plugins=[],
unloaded_plugins=unloaded_plugins,
failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
)
requested_plugin_success = plugin_id in reloaded_plugins
return ReloadPluginResultPayload(
success=requested_plugin_success,
@@ -978,7 +1076,11 @@ class PluginRunner:
)
async with self._reload_lock:
result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason)
result = await self._reload_plugin_by_id(
payload.plugin_id,
payload.reason,
external_available_plugins=set(payload.external_available_plugins),
)
return envelope.make_response(payload=result.model_dump())
def request_capability(self) -> RPCClient:
@@ -1073,6 +1175,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
async def _async_main() -> None:
"""异步主入口"""
host_address = os.environ.get(ENV_IPC_ADDRESS, "")
external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "")
session_token = os.environ.get(ENV_SESSION_TOKEN, "")
plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "")
@@ -1081,11 +1184,24 @@ async def _async_main() -> None:
sys.exit(1)
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
try:
external_plugin_ids = json.loads(external_plugin_ids_raw) if external_plugin_ids_raw else []
except json.JSONDecodeError:
logger.warning("解析外部依赖插件列表失败,已回退为空列表")
external_plugin_ids = []
if not isinstance(external_plugin_ids, list):
logger.warning("外部依赖插件列表格式非法,已回退为空列表")
external_plugin_ids = []
# sys.path 隔离: 只保留标准库、SDK 包、插件目录
_isolate_sys_path(plugin_dirs)
runner = PluginRunner(host_address, session_token, plugin_dirs)
runner = PluginRunner(
host_address,
session_token,
plugin_dirs,
external_available_plugin_ids=[str(plugin_id) for plugin_id in external_plugin_ids],
)
# 注册信号处理
def _mark_runner_shutting_down() -> None: