feat: Enhance plugin configuration management and SDK integration

- Add support for configuration reload scopes in the plugin runtime.
- Implement validation for SDK plugins to ensure required lifecycle methods are overridden.
- Update the configuration update handling to include scope information.
- Introduce tests for expression auto-check task and NapCat adapter SDK integration.
- Refactor configuration management to support callbacks with variable arguments.
- Improve plugin loading and error handling for configuration updates.
- Ensure that plugins can manage their own configuration updates effectively.
This commit is contained in:
DrSmoothl
2026-03-23 20:06:12 +08:00
parent 9dea6b0e6f
commit d13767ee21
16 changed files with 907 additions and 71 deletions

View File

@@ -0,0 +1,89 @@
"""测试表达方式自动检查任务的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_auto_check_engine")
def expression_auto_check_engine_fixture() -> Generator:
"""创建用于表达方式自动检查任务测试的内存数据库引擎。
Yields:
Generator: 供测试使用的 SQLite 内存引擎。
"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
@pytest.mark.asyncio
async def test_select_expressions_uses_read_only_session(
monkeypatch: pytest.MonkeyPatch,
expression_auto_check_engine,
) -> None:
"""选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。"""
import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module
with Session(expression_auto_check_engine) as session:
session.add(
Expression(
situation="表达情绪高涨或生理反应",
style="发送💦表情符号",
content_list='["表达情绪高涨或生理反应"]',
count=1,
session_id="session-a",
checked=False,
rejected=False,
)
)
session.commit()
auto_commit_calls: list[bool] = []
@contextmanager
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
"""构造带自动提交语义的测试会话工厂。
Args:
auto_commit: 退出上下文时是否自动提交。
Yields:
Generator[Session, None, None]: SQLModel 会话对象。
"""
auto_commit_calls.append(auto_commit)
session = Session(expression_auto_check_engine)
try:
yield session
if auto_commit:
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
task = ExpressionAutoCheckTask()
expressions = await task._select_expressions(1)
assert auto_commit_calls == [False]
assert len(expressions) == 1
assert expressions[0].id is not None
assert expressions[0].situation == "表达情绪高涨或生理反应"
assert expressions[0].style == "发送💦表情符号"

View File

@@ -0,0 +1,132 @@
"""NapCat 插件与新 SDK 对接测试。"""
from pathlib import Path
from typing import Any, Dict, List
import importlib
import logging
import sys
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)):
if import_path not in sys.path:
sys.path.insert(0, import_path)
class _FakeGatewayCapability:
"""用于捕获消息网关状态上报的测试替身。"""
def __init__(self) -> None:
"""初始化测试替身。"""
self.calls: List[Dict[str, Any]] = []
async def update_state(
self,
gateway_name: str,
*,
ready: bool,
platform: str = "",
account_id: str = "",
scope: str = "",
metadata: Dict[str, Any] | None = None,
) -> bool:
"""记录一次状态上报请求。
Args:
gateway_name: 网关组件名称。
ready: 当前是否就绪。
platform: 平台名称。
account_id: 账号 ID。
scope: 路由作用域。
metadata: 附加元数据。
Returns:
bool: 始终返回 ``True``,模拟 Host 接受状态更新。
"""
self.calls.append(
{
"gateway_name": gateway_name,
"ready": ready,
"platform": platform,
"account_id": account_id,
"scope": scope,
"metadata": metadata or {},
}
)
return True
def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]:
"""动态加载 NapCat 插件测试所需的符号。
Returns:
tuple[Any, Any, Any, Any]:
依次返回网关名常量、配置类、插件类和运行时状态管理器类。
"""
constants_module = importlib.import_module("napcat_adapter.constants")
config_module = importlib.import_module("napcat_adapter.config")
plugin_module = importlib.import_module("napcat_adapter.plugin")
runtime_state_module = importlib.import_module("napcat_adapter.runtime_state")
return (
constants_module.NAPCAT_GATEWAY_NAME,
config_module.NapCatServerConfig,
plugin_module.NapCatAdapterPlugin,
runtime_state_module.NapCatRuntimeStateManager,
)
def test_napcat_plugin_collects_duplex_message_gateway() -> None:
"""NapCat 插件应声明新的双工消息网关组件。"""
napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
plugin = napcat_plugin_cls()
components = plugin.get_components()
gateway_components = [
component
for component in components
if component.get("type") == "MESSAGE_GATEWAY"
]
assert len(gateway_components) == 1
gateway_component = gateway_components[0]
assert gateway_component["name"] == napcat_gateway_name
assert gateway_component["metadata"]["route_type"] == "duplex"
assert gateway_component["metadata"]["platform"] == "qq"
assert gateway_component["metadata"]["protocol"] == "napcat"
@pytest.mark.asyncio
async def test_runtime_state_reports_via_gateway_capability() -> None:
"""NapCat 运行时状态应通过新的消息网关能力上报。"""
napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
gateway_capability = _FakeGatewayCapability()
runtime_state_manager = runtime_state_cls(
gateway_capability=gateway_capability,
logger=logging.getLogger("test.napcat_adapter"),
gateway_name=napcat_gateway_name,
)
connected = await runtime_state_manager.report_connected(
"10001",
napcat_server_config_cls(connection_id="primary"),
)
await runtime_state_manager.report_disconnected()
assert connected is True
assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
assert gateway_capability.calls[0]["ready"] is True
assert gateway_capability.calls[0]["platform"] == "qq"
assert gateway_capability.calls[0]["account_id"] == "10001"
assert gateway_capability.calls[0]["scope"] == "primary"
assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
assert gateway_capability.calls[1]["ready"] is False
assert gateway_capability.calls[1]["platform"] == "qq"

View File

@@ -441,8 +441,8 @@ class TestSDK:
def set_plugin_config(self, config):
self.configs.append(config)
async def on_config_update(self, config, version):
self.updates.append((config, version, list(self.configs)))
async def on_config_update(self, scope, config, version):
self.updates.append((scope, config, version, list(self.configs)))
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
plugin = DummyPlugin()
@@ -453,14 +453,60 @@ class TestSDK:
message_type=MessageType.REQUEST,
method="plugin.config_updated",
plugin_id="demo_plugin",
payload={"config_data": {"enabled": True}, "config_version": "v2"},
payload={
"plugin_id": "demo_plugin",
"config_scope": "self",
"config_data": {"enabled": True},
"config_version": "v2",
},
)
response = await runner._handle_config_updated(envelope)
assert response.payload["acknowledged"] is True
assert plugin.configs == [{"enabled": True}]
assert plugin.updates == [({"enabled": True}, "v2", [{"enabled": True}])]
assert plugin.updates == [("self", {"enabled": True}, "v2", [{"enabled": True}])]
@pytest.mark.asyncio
async def test_runner_global_config_update_does_not_override_plugin_config(self):
"""bot/model 广播不应覆盖插件自身配置缓存。"""
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
from src.plugin_runtime.runner.runner_main import PluginRunner
class DummyPlugin:
def __init__(self):
self.configs = []
self.updates = []
def set_plugin_config(self, config):
self.configs.append(config)
async def on_config_update(self, scope, config, version):
self.updates.append((scope, config, version, list(self.configs)))
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
plugin = DummyPlugin()
runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin)
plugin.set_plugin_config({"plugin_enabled": True})
envelope = Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="plugin.config_updated",
plugin_id="demo_plugin",
payload={
"plugin_id": "demo_plugin",
"config_scope": "model",
"config_data": {"models": []},
"config_version": "",
},
)
response = await runner._handle_config_updated(envelope)
assert response.payload["acknowledged"] is True
assert plugin.configs == [{"plugin_enabled": True}]
assert plugin.updates == [("model", {"models": []}, "", [{"plugin_enabled": True}])]
@pytest.mark.asyncio
async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch):
@@ -911,6 +957,120 @@ class TestDependencyResolution:
assert loader.failed_plugins == {}
assert loaded[0].instance.answer() == 42
def test_loader_requires_sdk_plugin_to_override_on_config_update(self, tmp_path):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
plugin_root = tmp_path / "plugins"
plugin_root.mkdir()
plugin_dir = plugin_root / "demo_plugin"
plugin_dir.mkdir()
(plugin_dir / "_manifest.json").write_text(
json.dumps(
{
"name": "demo_plugin",
"version": "1.0.0",
"description": "demo",
"author": "tester",
}
),
encoding="utf-8",
)
(plugin_dir / "plugin.py").write_text(
"from maibot_sdk import MaiBotPlugin\n\n"
"class DemoPlugin(MaiBotPlugin):\n"
" async def on_load(self):\n"
" pass\n\n"
" async def on_unload(self):\n"
" pass\n\n"
"def create_plugin():\n"
" return DemoPlugin()\n",
encoding="utf-8",
)
loader = PluginLoader()
loaded = loader.discover_and_load([str(plugin_root)])
assert loaded == []
assert "demo_plugin" in loader.failed_plugins
assert "on_config_update" in loader.failed_plugins["demo_plugin"]
def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
plugin_root = tmp_path / "plugins"
plugin_root.mkdir()
plugin_dir = plugin_root / "demo_plugin"
plugin_dir.mkdir()
(plugin_dir / "_manifest.json").write_text(
json.dumps(
{
"name": "demo_plugin",
"version": "1.0.0",
"description": "demo",
"author": "tester",
}
),
encoding="utf-8",
)
(plugin_dir / "plugin.py").write_text(
"from maibot_sdk import MaiBotPlugin\n\n"
"class DemoPlugin(MaiBotPlugin):\n"
" async def on_unload(self):\n"
" pass\n\n"
" async def on_config_update(self, scope, config_data, version):\n"
" pass\n\n"
"def create_plugin():\n"
" return DemoPlugin()\n",
encoding="utf-8",
)
loader = PluginLoader()
loaded = loader.discover_and_load([str(plugin_root)])
assert loaded == []
assert "demo_plugin" in loader.failed_plugins
assert "on_load" in loader.failed_plugins["demo_plugin"]
def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
plugin_root = tmp_path / "plugins"
plugin_root.mkdir()
plugin_dir = plugin_root / "demo_plugin"
plugin_dir.mkdir()
(plugin_dir / "_manifest.json").write_text(
json.dumps(
{
"name": "demo_plugin",
"version": "1.0.0",
"description": "demo",
"author": "tester",
}
),
encoding="utf-8",
)
(plugin_dir / "plugin.py").write_text(
"from maibot_sdk import MaiBotPlugin\n\n"
"class DemoPlugin(MaiBotPlugin):\n"
" async def on_load(self):\n"
" pass\n\n"
" async def on_config_update(self, scope, config_data, version):\n"
" pass\n\n"
"def create_plugin():\n"
" return DemoPlugin()\n",
encoding="utf-8",
)
loader = PluginLoader()
loaded = loader.discover_and_load([str(plugin_root)])
assert loaded == []
assert "demo_plugin" in loader.failed_plugins
assert "on_unload" in loader.failed_plugins["demo_plugin"]
def test_isolate_sys_path_preserves_plugin_dirs(self):
from src.plugin_runtime.runner import runner_main
@@ -2299,9 +2459,10 @@ class TestIntegration:
assert refresh_calls == [True]
@pytest.mark.asyncio
async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path):
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module
from src.config.file_watcher import FileChange
import json
builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins"
@@ -2311,6 +2472,10 @@ class TestIntegration:
beta_dir.mkdir(parents=True)
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
(beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
monkeypatch.chdir(tmp_path)
@@ -2318,31 +2483,95 @@ class TestIntegration:
def __init__(self, plugin_dirs, plugins):
self._plugin_dirs = plugin_dirs
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
self.reload_calls = []
self.config_updates = []
async def reload_plugin(self, plugin_id, reason="manual"):
self.reload_calls.append((plugin_id, reason))
async def notify_plugin_config_updated(
self,
plugin_id,
config_data,
config_version="",
config_scope="self",
):
self.config_updates.append((plugin_id, config_data, config_version, config_scope))
return True
manager = integration_module.PluginRuntimeManager()
manager._started = True
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
refresh_calls = []
def fake_refresh() -> None:
refresh_calls.append(True)
manager._refresh_plugin_config_watch_subscriptions = fake_refresh
await manager._handle_plugin_config_changes(
"alpha",
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
)
assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")]
assert manager._third_party_supervisor.reload_calls == []
assert refresh_calls == [True]
assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "", "self")]
assert manager._third_party_supervisor.config_updates == []
@pytest.mark.asyncio
async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
class FakeRegistration:
def __init__(self, subscriptions):
self.config_reload_subscriptions = subscriptions
class FakeSupervisor:
def __init__(self, registrations):
self._registered_plugins = registrations
self.config_updates = []
def get_config_reload_subscribers(self, scope):
matched_plugins = []
for plugin_id, registration in self._registered_plugins.items():
if scope in registration.config_reload_subscriptions:
matched_plugins.append(plugin_id)
return matched_plugins
async def notify_plugin_config_updated(
self,
plugin_id,
config_data,
config_version="",
config_scope="self",
):
self.config_updates.append((plugin_id, config_data, config_version, config_scope))
return True
fake_global = SimpleNamespace(plugin_runtime=SimpleNamespace(enabled=True))
monkeypatch.setattr(
integration_module.config_manager,
"get_global_config",
lambda: SimpleNamespace(model_dump=lambda: {"bot": {"name": "MaiBot"}}, plugin_runtime=fake_global.plugin_runtime),
)
monkeypatch.setattr(
integration_module.config_manager,
"get_model_config",
lambda: SimpleNamespace(model_dump=lambda: {"models": [{"name": "demo"}]}),
)
manager = integration_module.PluginRuntimeManager()
manager._started = True
manager._builtin_supervisor = FakeSupervisor(
{
"alpha": FakeRegistration(["bot"]),
"beta": FakeRegistration([]),
}
)
manager._third_party_supervisor = FakeSupervisor(
{
"gamma": FakeRegistration(["model"]),
}
)
await manager._handle_main_config_reload(["bot", "model"])
assert manager._builtin_supervisor.config_updates == [
("alpha", {"bot": {"name": "MaiBot"}}, "", "bot")
]
assert manager._third_party_supervisor.config_updates == [
("gamma", {"models": [{"name": "demo"}]}, "", "model")
]
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
from src.plugin_runtime import integration as integration_module