feat: 添加批量插件重载功能及相关测试
This commit is contained in:
@@ -676,6 +676,101 @@ class TestSDK:
|
||||
methods = [call["method"] for call in runner._rpc_client.calls]
|
||||
assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_batch_reload_merges_overlapping_reverse_dependents(self, monkeypatch):
|
||||
"""批量重载应只对重叠依赖闭包执行一次 unload/load。"""
|
||||
from src.plugin_runtime.runner.runner_main import PluginRunner
|
||||
|
||||
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
|
||||
plugin_a_id = "test.plugin-a"
|
||||
plugin_b_id = "test.plugin-b"
|
||||
plugin_c_id = "test.plugin-c"
|
||||
|
||||
def build_meta(plugin_id: str, dependencies: list[str]) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
plugin_id=plugin_id,
|
||||
dependencies=dependencies,
|
||||
plugin_dir=f"/tmp/{plugin_id}",
|
||||
version="1.0.0",
|
||||
instance=SimpleNamespace(),
|
||||
)
|
||||
|
||||
loaded_metas = {
|
||||
plugin_a_id: build_meta(plugin_a_id, []),
|
||||
plugin_b_id: build_meta(plugin_b_id, [plugin_a_id]),
|
||||
plugin_c_id: build_meta(plugin_c_id, [plugin_b_id]),
|
||||
}
|
||||
reloaded_metas = {
|
||||
plugin_id: build_meta(plugin_id, list(meta.dependencies))
|
||||
for plugin_id, meta in loaded_metas.items()
|
||||
}
|
||||
candidates = {
|
||||
plugin_a_id: (
|
||||
"dir_plugin_a",
|
||||
build_test_manifest_model(plugin_a_id),
|
||||
"plugin_a/plugin.py",
|
||||
),
|
||||
plugin_b_id: (
|
||||
"dir_plugin_b",
|
||||
build_test_manifest_model(
|
||||
plugin_b_id,
|
||||
dependencies=[{"type": "plugin", "id": plugin_a_id, "version_spec": ">=1.0.0,<2.0.0"}],
|
||||
),
|
||||
"plugin_b/plugin.py",
|
||||
),
|
||||
plugin_c_id: (
|
||||
"dir_plugin_c",
|
||||
build_test_manifest_model(
|
||||
plugin_c_id,
|
||||
dependencies=[{"type": "plugin", "id": plugin_b_id, "version_spec": ">=1.0.0,<2.0.0"}],
|
||||
),
|
||||
"plugin_c/plugin.py",
|
||||
),
|
||||
}
|
||||
unloaded_plugins: list[str] = []
|
||||
activated_plugins: list[str] = []
|
||||
|
||||
monkeypatch.setattr(runner._loader, "discover_candidates", lambda plugin_dirs: (candidates, {}))
|
||||
monkeypatch.setattr(runner._loader, "list_plugins", lambda: sorted(loaded_metas.keys()))
|
||||
monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: loaded_metas.get(plugin_id))
|
||||
monkeypatch.setattr(
|
||||
runner._loader,
|
||||
"remove_loaded_plugin",
|
||||
lambda plugin_id: loaded_metas.pop(plugin_id, None),
|
||||
)
|
||||
monkeypatch.setattr(runner._loader, "purge_plugin_modules", lambda plugin_id, plugin_dir: [])
|
||||
monkeypatch.setattr(
|
||||
runner._loader,
|
||||
"resolve_dependencies",
|
||||
lambda reload_candidates, extra_available=None: (sorted(reload_candidates.keys()), {}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
runner._loader,
|
||||
"load_candidate",
|
||||
lambda plugin_id, candidate: reloaded_metas[plugin_id],
|
||||
)
|
||||
|
||||
async def fake_unload_plugin(meta, reason, purge_modules=False):
|
||||
del reason, purge_modules
|
||||
unloaded_plugins.append(meta.plugin_id)
|
||||
loaded_metas.pop(meta.plugin_id, None)
|
||||
|
||||
async def fake_activate_plugin(meta):
|
||||
activated_plugins.append(meta.plugin_id)
|
||||
loaded_metas[meta.plugin_id] = meta
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(runner, "_unload_plugin", fake_unload_plugin)
|
||||
monkeypatch.setattr(runner, "_activate_plugin", fake_activate_plugin)
|
||||
|
||||
result = await runner._reload_plugins_by_ids([plugin_a_id, plugin_b_id], reason="manual")
|
||||
|
||||
assert result.success is True
|
||||
assert result.requested_plugin_ids == [plugin_a_id, plugin_b_id]
|
||||
assert unloaded_plugins == [plugin_c_id, plugin_b_id, plugin_a_id]
|
||||
assert activated_plugins == [plugin_a_id, plugin_b_id, plugin_c_id]
|
||||
assert result.reloaded_plugins == [plugin_a_id, plugin_b_id, plugin_c_id]
|
||||
|
||||
|
||||
class TestPluginSdkUsage:
|
||||
"""验证仓库内插件按新 SDK 归一化返回值工作。"""
|
||||
@@ -1220,6 +1315,25 @@ class TestDependencyResolution:
|
||||
sys.path[:] = original_path
|
||||
sys.meta_path[:] = original_meta_path
|
||||
|
||||
def test_isolate_sys_path_blocks_disallowed_src_imports(self):
|
||||
import importlib
|
||||
|
||||
from src.plugin_runtime.runner import runner_main
|
||||
|
||||
original_path = list(sys.path)
|
||||
original_meta_path = list(sys.meta_path)
|
||||
sys.modules.pop("src.forbidden_demo", None)
|
||||
|
||||
try:
|
||||
runner_main._isolate_sys_path([])
|
||||
|
||||
with pytest.raises(ImportError, match="不允许导入主程序模块"):
|
||||
importlib.import_module("src.forbidden_demo")
|
||||
finally:
|
||||
sys.path[:] = original_path
|
||||
sys.meta_path[:] = original_meta_path
|
||||
sys.modules.pop("src.forbidden_demo", None)
|
||||
|
||||
|
||||
# ─── Host-side ComponentRegistry 测试 ──────────────────────
|
||||
|
||||
@@ -1264,6 +1378,30 @@ class TestComponentRegistry:
|
||||
assert stats["command"] == 1
|
||||
assert stats["tool"] == 1
|
||||
|
||||
def test_register_command_with_invalid_regex_only_warns(self, monkeypatch):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
warnings: list[str] = []
|
||||
monkeypatch.setattr(
|
||||
"src.plugin_runtime.host.component_registry.logger.warning",
|
||||
lambda message: warnings.append(str(message)),
|
||||
)
|
||||
|
||||
success = reg.register_component(
|
||||
"broken",
|
||||
"command",
|
||||
"plugin_a",
|
||||
{
|
||||
"command_pattern": "[",
|
||||
},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert reg.get_component("plugin_a.broken") is not None
|
||||
assert warnings
|
||||
assert "plugin_a.broken" in warnings[0]
|
||||
|
||||
def test_query_by_type(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
@@ -2303,6 +2441,39 @@ class TestSupervisor:
|
||||
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
|
||||
assert supervisor.component_registry.get_component("plugin_a.obsolete") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_plugins_uses_batch_rpc_for_multiple_roots(self):
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.protocol.envelope import ReloadPluginsResultPayload
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
sent_requests: list[tuple[str, dict[str, object], int]] = []
|
||||
|
||||
class FakeRPCServer:
|
||||
async def send_request(self, method, payload, timeout_ms=5000, **kwargs):
|
||||
del kwargs
|
||||
sent_requests.append((method, payload, timeout_ms))
|
||||
return SimpleNamespace(
|
||||
payload=ReloadPluginsResultPayload(
|
||||
success=True,
|
||||
requested_plugin_ids=["plugin_a", "plugin_b"],
|
||||
reloaded_plugins=["plugin_a", "plugin_b", "plugin_c"],
|
||||
unloaded_plugins=["plugin_c", "plugin_b", "plugin_a"],
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
supervisor._rpc_server = FakeRPCServer()
|
||||
|
||||
reloaded = await supervisor.reload_plugins(["plugin_a", "plugin_b", "plugin_a"], reason="manual")
|
||||
|
||||
assert reloaded is True
|
||||
assert len(sent_requests) == 1
|
||||
method, payload, timeout_ms = sent_requests[0]
|
||||
assert method == "plugin.reload_batch"
|
||||
assert payload["plugin_ids"] == ["plugin_a", "plugin_b"]
|
||||
assert payload["reason"] == "manual"
|
||||
assert timeout_ms >= 10000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch):
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
@@ -75,14 +75,14 @@ class CommandEntry(ComponentEntry):
|
||||
"""Command 组件条目"""
|
||||
|
||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
self.compiled_pattern: Optional[re.Pattern] = None
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
self.aliases: List[str] = metadata.get("aliases", [])
|
||||
self.compiled_pattern: Optional[re.Pattern] = None
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
try:
|
||||
self.compiled_pattern = re.compile(pattern)
|
||||
except re.error as e:
|
||||
except (re.error, TypeError) as e:
|
||||
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
|
||||
super().__init__(name, component_type, plugin_id, metadata)
|
||||
|
||||
|
||||
class ToolEntry(ComponentEntry):
|
||||
|
||||
@@ -33,6 +33,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
ReceiveExternalMessageResultPayload,
|
||||
RegisterPluginPayload,
|
||||
ReloadPluginResultPayload,
|
||||
ReloadPluginsPayload,
|
||||
ReloadPluginsResultPayload,
|
||||
RouteMessagePayload,
|
||||
RunnerReadyPayload,
|
||||
ShutdownPayload,
|
||||
@@ -194,6 +196,35 @@ class PluginRunnerSupervisor:
|
||||
for plugin_id, registration in self._registered_plugins.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_reload_plugin_ids(plugin_ids: Optional[List[str] | str]) -> List[str]:
|
||||
"""规范化批量重载入参。
|
||||
|
||||
Args:
|
||||
plugin_ids: 原始插件 ID 列表或单个插件 ID。
|
||||
|
||||
Returns:
|
||||
List[str]: 去重且去空白后的插件 ID 列表。
|
||||
"""
|
||||
|
||||
raw_plugin_ids: List[str]
|
||||
if plugin_ids is None:
|
||||
raw_plugin_ids = []
|
||||
elif isinstance(plugin_ids, str):
|
||||
raw_plugin_ids = [plugin_ids]
|
||||
else:
|
||||
raw_plugin_ids = list(plugin_ids)
|
||||
|
||||
normalized_plugin_ids: List[str] = []
|
||||
seen_plugin_ids: set[str] = set()
|
||||
for plugin_id in raw_plugin_ids:
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids:
|
||||
continue
|
||||
seen_plugin_ids.add(normalized_plugin_id)
|
||||
normalized_plugin_ids.append(normalized_plugin_id)
|
||||
return normalized_plugin_ids
|
||||
|
||||
async def dispatch_event(
|
||||
self,
|
||||
event_type: str,
|
||||
@@ -420,7 +451,7 @@ class PluginRunnerSupervisor:
|
||||
|
||||
async def reload_plugins(
|
||||
self,
|
||||
plugin_ids: Optional[List[str]] = None,
|
||||
plugin_ids: Optional[List[str] | str] = None,
|
||||
reason: str = "manual",
|
||||
external_available_plugins: Optional[Dict[str, str]] = None,
|
||||
) -> bool:
|
||||
@@ -434,19 +465,37 @@ class PluginRunnerSupervisor:
|
||||
Returns:
|
||||
bool: 是否全部重载成功。
|
||||
"""
|
||||
target_plugin_ids = plugin_ids or list(self._registered_plugins.keys())
|
||||
ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids))
|
||||
success = True
|
||||
ordered_plugin_ids = self._normalize_reload_plugin_ids(plugin_ids)
|
||||
if not ordered_plugin_ids:
|
||||
ordered_plugin_ids = list(self._registered_plugins.keys())
|
||||
if not ordered_plugin_ids:
|
||||
return True
|
||||
|
||||
for plugin_id in ordered_plugin_ids:
|
||||
reloaded = await self.reload_plugin(
|
||||
plugin_id=plugin_id,
|
||||
if len(ordered_plugin_ids) == 1:
|
||||
return await self.reload_plugin(
|
||||
plugin_id=ordered_plugin_ids[0],
|
||||
reason=reason,
|
||||
external_available_plugins=external_available_plugins,
|
||||
)
|
||||
success = success and reloaded
|
||||
|
||||
return success
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
"plugin.reload_batch",
|
||||
payload=ReloadPluginsPayload(
|
||||
plugin_ids=ordered_plugin_ids,
|
||||
reason=reason,
|
||||
external_available_plugins=external_available_plugins or self._external_available_plugins,
|
||||
).model_dump(),
|
||||
timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件批量重载请求失败: {exc}")
|
||||
return False
|
||||
|
||||
result = ReloadPluginsResultPayload.model_validate(response.payload)
|
||||
if not result.success:
|
||||
logger.warning(f"插件批量重载失败: {result.failed_plugins}")
|
||||
return result.success
|
||||
|
||||
async def notify_plugin_config_updated(
|
||||
self,
|
||||
|
||||
@@ -289,6 +289,20 @@ class ReloadPluginPayload(BaseModel):
|
||||
"""可视为已满足的外部依赖插件版本映射"""
|
||||
|
||||
|
||||
class ReloadPluginsPayload(BaseModel):
|
||||
"""批量插件重载请求载荷。"""
|
||||
|
||||
plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表")
|
||||
"""目标插件 ID 列表"""
|
||||
reason: str = Field(default="manual", description="重载原因")
|
||||
"""重载原因"""
|
||||
external_available_plugins: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="可视为已满足的外部依赖插件版本映射",
|
||||
)
|
||||
"""可视为已满足的外部依赖插件版本映射"""
|
||||
|
||||
|
||||
class ReloadPluginResultPayload(BaseModel):
|
||||
"""插件重载结果载荷。"""
|
||||
|
||||
@@ -304,6 +318,21 @@ class ReloadPluginResultPayload(BaseModel):
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
class ReloadPluginsResultPayload(BaseModel):
|
||||
"""批量插件重载结果载荷。"""
|
||||
|
||||
success: bool = Field(description="是否重载成功")
|
||||
"""是否重载成功"""
|
||||
requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表")
|
||||
"""请求重载的插件 ID 列表"""
|
||||
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
|
||||
"""成功完成重载的插件列表"""
|
||||
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
|
||||
"""本次已卸载的插件列表"""
|
||||
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
class MessageGatewayStateUpdatePayload(BaseModel):
|
||||
"""消息网关运行时状态更新载荷。"""
|
||||
|
||||
|
||||
@@ -42,6 +42,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
RegisterPluginPayload,
|
||||
ReloadPluginPayload,
|
||||
ReloadPluginResultPayload,
|
||||
ReloadPluginsPayload,
|
||||
ReloadPluginsResultPayload,
|
||||
RunnerReadyPayload,
|
||||
UnregisterPluginPayload,
|
||||
)
|
||||
@@ -334,6 +336,7 @@ class PluginRunner:
|
||||
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
|
||||
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
|
||||
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
|
||||
self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_component_handler_name(meta: PluginMeta, component_name: str) -> str:
|
||||
@@ -597,6 +600,21 @@ class PluginRunner:
|
||||
|
||||
return impacted_plugins
|
||||
|
||||
def _collect_reverse_dependents_for_roots(self, plugin_ids: Set[str]) -> Set[str]:
|
||||
"""收集多个根插件对应的反向依赖并集。
|
||||
|
||||
Args:
|
||||
plugin_ids: 根插件 ID 集合。
|
||||
|
||||
Returns:
|
||||
Set[str]: 所有根插件及其反向依赖并集。
|
||||
"""
|
||||
|
||||
impacted_plugins: Set[str] = set()
|
||||
for plugin_id in sorted(plugin_ids):
|
||||
impacted_plugins.update(self._collect_reverse_dependents(plugin_id))
|
||||
return impacted_plugins
|
||||
|
||||
def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]:
|
||||
"""构建受影响插件的卸载顺序。
|
||||
|
||||
@@ -635,6 +653,20 @@ class PluginRunner:
|
||||
|
||||
return list(reversed(load_order))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_requested_plugin_ids(plugin_ids: List[str]) -> List[str]:
|
||||
"""规范化批量重载请求中的插件 ID 列表。"""
|
||||
|
||||
normalized_plugin_ids: List[str] = []
|
||||
seen_plugin_ids: Set[str] = set()
|
||||
for plugin_id in plugin_ids:
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids:
|
||||
continue
|
||||
seen_plugin_ids.add(normalized_plugin_id)
|
||||
normalized_plugin_ids.append(normalized_plugin_id)
|
||||
return normalized_plugin_ids
|
||||
|
||||
@staticmethod
|
||||
def _finalize_failed_reload_messages(
|
||||
failed_plugins: Dict[str, str],
|
||||
@@ -674,6 +706,31 @@ class PluginRunner:
|
||||
Returns:
|
||||
ReloadPluginResultPayload: 结构化重载结果。
|
||||
"""
|
||||
batch_result = await self._reload_plugins_by_ids(
|
||||
[plugin_id],
|
||||
reason,
|
||||
external_available_plugins=external_available_plugins,
|
||||
)
|
||||
return ReloadPluginResultPayload(
|
||||
success=batch_result.success,
|
||||
requested_plugin_id=plugin_id,
|
||||
reloaded_plugins=batch_result.reloaded_plugins,
|
||||
unloaded_plugins=batch_result.unloaded_plugins,
|
||||
failed_plugins=batch_result.failed_plugins,
|
||||
)
|
||||
|
||||
async def _reload_plugins_by_ids(
|
||||
self,
|
||||
plugin_ids: List[str],
|
||||
reason: str,
|
||||
external_available_plugins: Optional[Dict[str, str]] = None,
|
||||
) -> ReloadPluginsResultPayload:
|
||||
"""按插件 ID 列表在 Runner 进程内执行一次批量重载。"""
|
||||
|
||||
normalized_plugin_ids = self._normalize_requested_plugin_ids(plugin_ids)
|
||||
if not normalized_plugin_ids:
|
||||
return ReloadPluginsResultPayload(success=True, requested_plugin_ids=[])
|
||||
|
||||
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
|
||||
failed_plugins: Dict[str, str] = {}
|
||||
normalized_external_available = {
|
||||
@@ -682,28 +739,35 @@ class PluginRunner:
|
||||
if str(candidate_plugin_id or "").strip() and str(candidate_plugin_version or "").strip()
|
||||
}
|
||||
|
||||
if plugin_id in duplicate_candidates:
|
||||
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
|
||||
return ReloadPluginResultPayload(
|
||||
success=False,
|
||||
requested_plugin_id=plugin_id,
|
||||
failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"},
|
||||
)
|
||||
|
||||
loaded_plugin_ids = set(self._loader.list_plugins())
|
||||
plugin_is_loaded = plugin_id in loaded_plugin_ids
|
||||
plugin_has_candidate = plugin_id in candidates
|
||||
reload_root_ids: Set[str] = set()
|
||||
for plugin_id in normalized_plugin_ids:
|
||||
if plugin_id in duplicate_candidates:
|
||||
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
|
||||
failed_plugins[plugin_id] = f"检测到重复插件 ID: {conflict_paths}"
|
||||
continue
|
||||
|
||||
if not plugin_is_loaded and not plugin_has_candidate:
|
||||
return ReloadPluginResultPayload(
|
||||
plugin_is_loaded = plugin_id in loaded_plugin_ids
|
||||
plugin_has_candidate = plugin_id in candidates
|
||||
if not plugin_is_loaded and not plugin_has_candidate:
|
||||
failed_plugins[plugin_id] = "插件不存在或未找到合法的 manifest/plugin.py"
|
||||
continue
|
||||
|
||||
reload_root_ids.add(plugin_id)
|
||||
|
||||
if not reload_root_ids:
|
||||
return ReloadPluginsResultPayload(
|
||||
success=False,
|
||||
requested_plugin_id=plugin_id,
|
||||
failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"},
|
||||
requested_plugin_ids=normalized_plugin_ids,
|
||||
failed_plugins=failed_plugins,
|
||||
)
|
||||
|
||||
target_plugin_ids: Set[str] = {plugin_id}
|
||||
if plugin_is_loaded:
|
||||
target_plugin_ids = self._collect_reverse_dependents(plugin_id)
|
||||
target_plugin_ids: Set[str] = {
|
||||
plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids
|
||||
}
|
||||
loaded_root_plugin_ids = reload_root_ids & loaded_plugin_ids
|
||||
if loaded_root_plugin_ids:
|
||||
target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids))
|
||||
|
||||
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
|
||||
unloaded_plugins: List[str] = []
|
||||
@@ -813,19 +877,19 @@ class PluginRunner:
|
||||
if not restored:
|
||||
rollback_failures[rollback_plugin_id] = "无法重新激活旧版本"
|
||||
|
||||
return ReloadPluginResultPayload(
|
||||
return ReloadPluginsResultPayload(
|
||||
success=False,
|
||||
requested_plugin_id=plugin_id,
|
||||
requested_plugin_ids=normalized_plugin_ids,
|
||||
reloaded_plugins=[],
|
||||
unloaded_plugins=unloaded_plugins,
|
||||
failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures),
|
||||
)
|
||||
|
||||
requested_plugin_success = plugin_id in reloaded_plugins
|
||||
requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids)
|
||||
|
||||
return ReloadPluginResultPayload(
|
||||
success=requested_plugin_success,
|
||||
requested_plugin_id=plugin_id,
|
||||
return ReloadPluginsResultPayload(
|
||||
success=requested_plugin_success and not failed_plugins,
|
||||
requested_plugin_ids=normalized_plugin_ids,
|
||||
reloaded_plugins=reloaded_plugins,
|
||||
unloaded_plugins=unloaded_plugins,
|
||||
failed_plugins=failed_plugins,
|
||||
@@ -1139,6 +1203,29 @@ class PluginRunner:
|
||||
)
|
||||
return envelope.make_response(payload=result.model_dump())
|
||||
|
||||
async def _handle_reload_plugins(self, envelope: Envelope) -> Envelope:
|
||||
"""处理批量插件重载请求。"""
|
||||
|
||||
try:
|
||||
payload = ReloadPluginsPayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
if self._reload_lock.locked():
|
||||
requested_plugin_ids = ", ".join(self._normalize_requested_plugin_ids(payload.plugin_ids)) or "<empty>"
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_RELOAD_IN_PROGRESS.value,
|
||||
f"插件 {requested_plugin_ids} 批量重载请求被拒绝:已有重载任务正在执行",
|
||||
)
|
||||
|
||||
async with self._reload_lock:
|
||||
result = await self._reload_plugins_by_ids(
|
||||
list(payload.plugin_ids),
|
||||
payload.reason,
|
||||
external_available_plugins=dict(payload.external_available_plugins),
|
||||
)
|
||||
return envelope.make_response(payload=result.model_dump())
|
||||
|
||||
def request_capability(self) -> RPCClient:
|
||||
"""获取 RPC 客户端(供 SDK 使用,发起能力调用)"""
|
||||
return self._rpc_client
|
||||
@@ -1153,6 +1240,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
||||
防止插件代码 import 主程序模块读取运行时数据。
|
||||
"""
|
||||
import importlib.abc
|
||||
from importlib.machinery import ModuleSpec
|
||||
import sysconfig
|
||||
|
||||
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
|
||||
@@ -1195,6 +1283,20 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
||||
|
||||
# 安装 import 钩子,阻止插件导入主程序核心模块
|
||||
# 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包
|
||||
class _BlockedSrcModuleLoader(importlib.abc.Loader):
|
||||
"""阻止被 Runner 允许列表之外的主程序模块完成导入。"""
|
||||
|
||||
def __init__(self, fullname: str) -> None:
|
||||
self._fullname = fullname
|
||||
|
||||
def create_module(self, spec: ModuleSpec) -> None:
|
||||
del spec
|
||||
return None
|
||||
|
||||
def exec_module(self, module: Any) -> None:
|
||||
del module
|
||||
raise ImportError(f"Runner 子进程不允许导入主程序模块: {self._fullname}")
|
||||
|
||||
class _PluginImportBlocker(importlib.abc.MetaPathFinder):
|
||||
"""阻止 Runner 子进程导入主程序核心模块。
|
||||
|
||||
@@ -1203,14 +1305,15 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
||||
"""
|
||||
|
||||
_ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")
|
||||
__maibot_runner_plugin_import_blocker__ = True
|
||||
|
||||
def find_module(self, fullname: str, path: Any = None) -> Any:
|
||||
def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None:
|
||||
"""决定是否拦截指定模块导入。"""
|
||||
return self if self._should_block(fullname) else None
|
||||
|
||||
def load_module(self, fullname: str) -> None:
|
||||
"""阻止被拦截模块继续导入。"""
|
||||
raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
|
||||
del path, target
|
||||
if not self._should_block(fullname):
|
||||
return None
|
||||
# Python 3.13+/3.14 会优先走 find_spec,不再依赖 find_module。
|
||||
return ModuleSpec(fullname, _BlockedSrcModuleLoader(fullname), is_package=True)
|
||||
|
||||
def _should_block(self, fullname: str) -> bool:
|
||||
"""判断给定模块名是否应被阻止导入。"""
|
||||
@@ -1222,6 +1325,11 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
||||
fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES
|
||||
)
|
||||
|
||||
sys.meta_path[:] = [
|
||||
finder
|
||||
for finder in sys.meta_path
|
||||
if not getattr(finder, "__maibot_runner_plugin_import_blocker__", False)
|
||||
]
|
||||
sys.meta_path.insert(0, _PluginImportBlocker())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user