diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index f9d66e10..7c383cbb 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -6,6 +6,7 @@ from types import SimpleNamespace import asyncio +import json import os import sys @@ -871,6 +872,63 @@ class TestDependencyResolution: order, failed = loader._resolve_dependencies(candidates) assert len(failed) >= 1 # 至少一个循环插件被标记 + def test_loader_supports_package_imports_inside_create_plugin(self, tmp_path): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + plugin_root = tmp_path / "plugins" + plugin_root.mkdir() + plugin_dir = plugin_root / "grok_search_plugin" + plugin_dir.mkdir() + + (plugin_dir / "_manifest.json").write_text( + json.dumps( + { + "name": "grok_search_plugin", + "version": "1.0.0", + "description": "demo", + "author": "tester", + } + ), + encoding="utf-8", + ) + (plugin_dir / "__init__.py").write_text("VALUE = 1\n", encoding="utf-8") + (plugin_dir / "services.py").write_text("def answer():\n return 42\n", encoding="utf-8") + (plugin_dir / "plugin.py").write_text( + "class DemoPlugin:\n" + " pass\n\n" + "def create_plugin():\n" + " from grok_search_plugin.services import answer\n" + " plugin = DemoPlugin()\n" + " plugin.answer = answer\n" + " return plugin\n", + encoding="utf-8", + ) + + loader = PluginLoader() + loaded = loader.discover_and_load([str(plugin_root)]) + + assert [meta.plugin_id for meta in loaded] == ["grok_search_plugin"] + assert loader.failed_plugins == {} + assert loaded[0].instance.answer() == 42 + + def test_isolate_sys_path_preserves_plugin_dirs(self): + from src.plugin_runtime.runner import runner_main + + plugin_root = os.path.normpath("/tmp/maibot-plugin-root") + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + if plugin_root in sys.path: + sys.path.remove(plugin_root) + + runner_main._isolate_sys_path([plugin_root]) + + assert plugin_root in sys.path + finally: + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + # ─── Host-side ComponentRegistry 测试 ────────────────────── diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 9efe184d..1b055dfe 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -241,47 +241,59 @@ class PluginLoader: module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module - plugin_parent_dir = os.path.dirname(plugin_dir) - inserted_plugin_parent = False - if plugin_parent_dir and plugin_parent_dir not in sys.path: - sys.path.insert(0, plugin_parent_dir) - inserted_plugin_parent = True - - try: + plugin_parent_dir = os.path.normpath(os.path.dirname(plugin_dir)) + with self._temporary_sys_path_entry(plugin_parent_dir): spec.loader.exec_module(module) - finally: - if inserted_plugin_parent: - with contextlib.suppress(ValueError): - sys.path.remove(plugin_parent_dir) - # 优先使用新版 create_plugin 工厂函数 - create_plugin = getattr(module, "create_plugin", None) - if create_plugin is not None: - instance = create_plugin() - logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") - return PluginMeta( - plugin_id=plugin_id, - plugin_dir=plugin_dir, - plugin_instance=instance, - manifest=manifest, - ) + # 优先使用新版 create_plugin 工厂函数 + create_plugin = getattr(module, "create_plugin", None) + if create_plugin is not None: + instance = create_plugin() + logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=plugin_dir, + plugin_instance=instance, + manifest=manifest, + ) - # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类 - instance = self._try_load_legacy_plugin(module, plugin_id) - if instance is not None: - logger.info( - f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" - ) - return PluginMeta( - plugin_id=plugin_id, - plugin_dir=plugin_dir, - plugin_instance=instance, - manifest=manifest, - ) + # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类 + instance = self._try_load_legacy_plugin(module, plugin_id) + if instance is not None: + logger.info( + f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" + ) + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=plugin_dir, + plugin_instance=instance, + manifest=manifest, + ) logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin") return None + @staticmethod + @contextlib.contextmanager + def _temporary_sys_path_entry(path: str): + """临时将路径放入 sys.path 头部,并在离开作用域后恢复。""" + if not path: + yield + return + + normalized_path = os.path.normpath(path) + existing_paths = {os.path.normpath(entry) for entry in sys.path} + inserted = normalized_path not in existing_paths + if inserted: + sys.path.insert(0, normalized_path) + + try: + yield + finally: + if inserted: + with contextlib.suppress(ValueError): + sys.path.remove(normalized_path) + # ──── 旧版插件兼容 ──────────────────────────────────────── def _ensure_compat_hook(self) -> None: