feat: Enhance plugin configuration management and SDK integration
- Add support for configuration reload scopes in the plugin runtime. - Implement validation for SDK plugins to ensure required lifecycle methods are overridden. - Update the configuration update handling to include scope information. - Introduce tests for expression auto-check task and NapCat adapter SDK integration. - Refactor configuration management to support callbacks with variable arguments. - Improve plugin loading and error handling for configuration updates. - Ensure that plugins can manage their own configuration updates effectively.
This commit is contained in:
@@ -3,12 +3,18 @@
|
|||||||
通过 /chat 命令设置和查看聊天频率。
|
通过 /chat 命令设置和查看聊天频率。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from maibot_sdk import MaiBotPlugin, Command
|
from maibot_sdk import Command, MaiBotPlugin
|
||||||
|
|
||||||
|
|
||||||
class BetterFrequencyPlugin(MaiBotPlugin):
|
class BetterFrequencyPlugin(MaiBotPlugin):
|
||||||
"""聊天频率控制插件"""
|
"""聊天频率控制插件"""
|
||||||
|
|
||||||
|
async def on_load(self) -> None:
|
||||||
|
"""处理插件加载。"""
|
||||||
|
|
||||||
|
async def on_unload(self) -> None:
|
||||||
|
"""处理插件卸载。"""
|
||||||
|
|
||||||
@Command(
|
@Command(
|
||||||
"set_talk_frequency",
|
"set_talk_frequency",
|
||||||
description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>",
|
description="设置当前聊天的talk_frequency值:/chat talk_frequency <数字> 或 /chat t <数字>",
|
||||||
@@ -80,6 +86,25 @@ class BetterFrequencyPlugin(MaiBotPlugin):
|
|||||||
await self.ctx.send.text(status_msg, stream_id)
|
await self.ctx.send.text(status_msg, stream_id)
|
||||||
return True, None, False
|
return True, None, False
|
||||||
|
|
||||||
|
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||||
|
"""处理配置热重载事件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: 配置变更范围。
|
||||||
|
config_data: 最新配置数据。
|
||||||
|
version: 配置版本号。
|
||||||
|
"""
|
||||||
|
|
||||||
|
del scope
|
||||||
|
del config_data
|
||||||
|
del version
|
||||||
|
|
||||||
|
|
||||||
|
def create_plugin() -> BetterFrequencyPlugin:
|
||||||
|
"""创建聊天频率插件实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BetterFrequencyPlugin: 新的聊天频率插件实例。
|
||||||
|
"""
|
||||||
|
|
||||||
def create_plugin():
|
|
||||||
return BetterFrequencyPlugin()
|
return BetterFrequencyPlugin()
|
||||||
|
|||||||
@@ -3,17 +3,23 @@
|
|||||||
通过 /emoji 命令管理表情包的添加、列表和删除。
|
通过 /emoji 命令管理表情包的添加、列表和删除。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from maibot_sdk import Command, MaiBotPlugin
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from maibot_sdk import MaiBotPlugin, Command
|
|
||||||
|
|
||||||
|
|
||||||
class EmojiManagePlugin(MaiBotPlugin):
|
class EmojiManagePlugin(MaiBotPlugin):
|
||||||
"""表情包管理插件"""
|
"""表情包管理插件"""
|
||||||
|
|
||||||
|
async def on_load(self) -> None:
|
||||||
|
"""处理插件加载。"""
|
||||||
|
|
||||||
|
async def on_unload(self) -> None:
|
||||||
|
"""处理插件卸载。"""
|
||||||
|
|
||||||
# ===== 工具方法 =====
|
# ===== 工具方法 =====
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -208,6 +214,25 @@ class EmojiManagePlugin(MaiBotPlugin):
|
|||||||
await self.ctx.send.forward(messages, stream_id)
|
await self.ctx.send.forward(messages, stream_id)
|
||||||
return True, "已发送随机表情包", True
|
return True, "已发送随机表情包", True
|
||||||
|
|
||||||
|
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||||
|
"""处理配置热重载事件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: 配置变更范围。
|
||||||
|
config_data: 最新配置数据。
|
||||||
|
version: 配置版本号。
|
||||||
|
"""
|
||||||
|
|
||||||
|
del scope
|
||||||
|
del config_data
|
||||||
|
del version
|
||||||
|
|
||||||
|
|
||||||
|
def create_plugin() -> EmojiManagePlugin:
|
||||||
|
"""创建表情包管理插件实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmojiManagePlugin: 新的表情包管理插件实例。
|
||||||
|
"""
|
||||||
|
|
||||||
def create_plugin():
|
|
||||||
return EmojiManagePlugin()
|
return EmojiManagePlugin()
|
||||||
|
|||||||
@@ -3,16 +3,22 @@
|
|||||||
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
|
你的第一个 MaiCore 插件,包含问候功能、时间查询等基础示例。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from maibot_sdk import Action, Command, EventHandler, MaiBotPlugin, Tool
|
||||||
|
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
|
|
||||||
from maibot_sdk.types import ActivationType, EventType, ToolParameterInfo, ToolParamType
|
|
||||||
|
|
||||||
|
|
||||||
class HelloWorldPlugin(MaiBotPlugin):
|
class HelloWorldPlugin(MaiBotPlugin):
|
||||||
"""Hello World 示例插件"""
|
"""Hello World 示例插件"""
|
||||||
|
|
||||||
|
async def on_load(self) -> None:
|
||||||
|
"""处理插件加载。"""
|
||||||
|
|
||||||
|
async def on_unload(self) -> None:
|
||||||
|
"""处理插件卸载。"""
|
||||||
|
|
||||||
# ===== Tool 组件 =====
|
# ===== Tool 组件 =====
|
||||||
|
|
||||||
@Tool(
|
@Tool(
|
||||||
@@ -146,6 +152,25 @@ class HelloWorldPlugin(MaiBotPlugin):
|
|||||||
|
|
||||||
return True, True, None, None, None
|
return True, True, None, None, None
|
||||||
|
|
||||||
|
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||||
|
"""处理配置热重载事件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: 配置变更范围。
|
||||||
|
config_data: 最新配置数据。
|
||||||
|
version: 配置版本号。
|
||||||
|
"""
|
||||||
|
|
||||||
|
del scope
|
||||||
|
del config_data
|
||||||
|
del version
|
||||||
|
|
||||||
|
|
||||||
|
def create_plugin() -> HelloWorldPlugin:
|
||||||
|
"""创建 Hello World 示例插件实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HelloWorldPlugin: 新的示例插件实例。
|
||||||
|
"""
|
||||||
|
|
||||||
def create_plugin():
|
|
||||||
return HelloWorldPlugin()
|
return HelloWorldPlugin()
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ dev = [
|
|||||||
[tool.uv]
|
[tool.uv]
|
||||||
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
maibot-plugin-sdk = { path = "packages/maibot-plugin-sdk", editable = true }
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|
||||||
|
|||||||
89
pytests/common_test/test_expression_auto_check_task.py
Normal file
89
pytests/common_test/test_expression_auto_check_task.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""测试表达方式自动检查任务的数据库读取行为。"""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
from sqlmodel import Session, SQLModel, create_engine
|
||||||
|
|
||||||
|
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
|
||||||
|
from src.common.database.database_model import Expression
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="expression_auto_check_engine")
|
||||||
|
def expression_auto_check_engine_fixture() -> Generator:
|
||||||
|
"""创建用于表达方式自动检查任务测试的内存数据库引擎。
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Generator: 供测试使用的 SQLite 内存引擎。
|
||||||
|
"""
|
||||||
|
|
||||||
|
engine = create_engine(
|
||||||
|
"sqlite://",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
yield engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_select_expressions_uses_read_only_session(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
expression_auto_check_engine,
|
||||||
|
) -> None:
|
||||||
|
"""选择表达方式时应使用只读会话,并在离开会话后安全读取 ORM 字段。"""
|
||||||
|
|
||||||
|
import src.bw_learner.expression_auto_check_task as expression_auto_check_task_module
|
||||||
|
|
||||||
|
with Session(expression_auto_check_engine) as session:
|
||||||
|
session.add(
|
||||||
|
Expression(
|
||||||
|
situation="表达情绪高涨或生理反应",
|
||||||
|
style="发送💦表情符号",
|
||||||
|
content_list='["表达情绪高涨或生理反应"]',
|
||||||
|
count=1,
|
||||||
|
session_id="session-a",
|
||||||
|
checked=False,
|
||||||
|
rejected=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
auto_commit_calls: list[bool] = []
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||||
|
"""构造带自动提交语义的测试会话工厂。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auto_commit: 退出上下文时是否自动提交。
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Generator[Session, None, None]: SQLModel 会话对象。
|
||||||
|
"""
|
||||||
|
|
||||||
|
auto_commit_calls.append(auto_commit)
|
||||||
|
session = Session(expression_auto_check_engine)
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
if auto_commit:
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
monkeypatch.setattr(expression_auto_check_task_module, "get_db_session", fake_get_db_session)
|
||||||
|
monkeypatch.setattr(expression_auto_check_task_module.random, "sample", lambda entries, _count: list(entries))
|
||||||
|
|
||||||
|
task = ExpressionAutoCheckTask()
|
||||||
|
expressions = await task._select_expressions(1)
|
||||||
|
|
||||||
|
assert auto_commit_calls == [False]
|
||||||
|
assert len(expressions) == 1
|
||||||
|
assert expressions[0].id is not None
|
||||||
|
assert expressions[0].situation == "表达情绪高涨或生理反应"
|
||||||
|
assert expressions[0].style == "发送💦表情符号"
|
||||||
132
pytests/test_napcat_adapter_sdk.py
Normal file
132
pytests/test_napcat_adapter_sdk.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""NapCat 插件与新 SDK 对接测试。"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
PLUGINS_ROOT = PROJECT_ROOT / "plugins"
|
||||||
|
SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk"
|
||||||
|
|
||||||
|
for import_path in (str(PLUGINS_ROOT), str(SDK_ROOT)):
|
||||||
|
if import_path not in sys.path:
|
||||||
|
sys.path.insert(0, import_path)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeGatewayCapability:
|
||||||
|
"""用于捕获消息网关状态上报的测试替身。"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""初始化测试替身。"""
|
||||||
|
|
||||||
|
self.calls: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def update_state(
|
||||||
|
self,
|
||||||
|
gateway_name: str,
|
||||||
|
*,
|
||||||
|
ready: bool,
|
||||||
|
platform: str = "",
|
||||||
|
account_id: str = "",
|
||||||
|
scope: str = "",
|
||||||
|
metadata: Dict[str, Any] | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""记录一次状态上报请求。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gateway_name: 网关组件名称。
|
||||||
|
ready: 当前是否就绪。
|
||||||
|
platform: 平台名称。
|
||||||
|
account_id: 账号 ID。
|
||||||
|
scope: 路由作用域。
|
||||||
|
metadata: 附加元数据。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 始终返回 ``True``,模拟 Host 接受状态更新。
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.calls.append(
|
||||||
|
{
|
||||||
|
"gateway_name": gateway_name,
|
||||||
|
"ready": ready,
|
||||||
|
"platform": platform,
|
||||||
|
"account_id": account_id,
|
||||||
|
"scope": scope,
|
||||||
|
"metadata": metadata or {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _load_napcat_sdk_symbols() -> tuple[Any, Any, Any, Any]:
|
||||||
|
"""动态加载 NapCat 插件测试所需的符号。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[Any, Any, Any, Any]:
|
||||||
|
依次返回网关名常量、配置类、插件类和运行时状态管理器类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
constants_module = importlib.import_module("napcat_adapter.constants")
|
||||||
|
config_module = importlib.import_module("napcat_adapter.config")
|
||||||
|
plugin_module = importlib.import_module("napcat_adapter.plugin")
|
||||||
|
runtime_state_module = importlib.import_module("napcat_adapter.runtime_state")
|
||||||
|
return (
|
||||||
|
constants_module.NAPCAT_GATEWAY_NAME,
|
||||||
|
config_module.NapCatServerConfig,
|
||||||
|
plugin_module.NapCatAdapterPlugin,
|
||||||
|
runtime_state_module.NapCatRuntimeStateManager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_napcat_plugin_collects_duplex_message_gateway() -> None:
|
||||||
|
"""NapCat 插件应声明新的双工消息网关组件。"""
|
||||||
|
|
||||||
|
napcat_gateway_name, _napcat_server_config, napcat_plugin_cls, _runtime_state_cls = _load_napcat_sdk_symbols()
|
||||||
|
plugin = napcat_plugin_cls()
|
||||||
|
components = plugin.get_components()
|
||||||
|
gateway_components = [
|
||||||
|
component
|
||||||
|
for component in components
|
||||||
|
if component.get("type") == "MESSAGE_GATEWAY"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(gateway_components) == 1
|
||||||
|
gateway_component = gateway_components[0]
|
||||||
|
assert gateway_component["name"] == napcat_gateway_name
|
||||||
|
assert gateway_component["metadata"]["route_type"] == "duplex"
|
||||||
|
assert gateway_component["metadata"]["platform"] == "qq"
|
||||||
|
assert gateway_component["metadata"]["protocol"] == "napcat"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_state_reports_via_gateway_capability() -> None:
|
||||||
|
"""NapCat 运行时状态应通过新的消息网关能力上报。"""
|
||||||
|
|
||||||
|
napcat_gateway_name, napcat_server_config_cls, _napcat_plugin_cls, runtime_state_cls = _load_napcat_sdk_symbols()
|
||||||
|
gateway_capability = _FakeGatewayCapability()
|
||||||
|
runtime_state_manager = runtime_state_cls(
|
||||||
|
gateway_capability=gateway_capability,
|
||||||
|
logger=logging.getLogger("test.napcat_adapter"),
|
||||||
|
gateway_name=napcat_gateway_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
connected = await runtime_state_manager.report_connected(
|
||||||
|
"10001",
|
||||||
|
napcat_server_config_cls(connection_id="primary"),
|
||||||
|
)
|
||||||
|
await runtime_state_manager.report_disconnected()
|
||||||
|
|
||||||
|
assert connected is True
|
||||||
|
assert gateway_capability.calls[0]["gateway_name"] == napcat_gateway_name
|
||||||
|
assert gateway_capability.calls[0]["ready"] is True
|
||||||
|
assert gateway_capability.calls[0]["platform"] == "qq"
|
||||||
|
assert gateway_capability.calls[0]["account_id"] == "10001"
|
||||||
|
assert gateway_capability.calls[0]["scope"] == "primary"
|
||||||
|
assert gateway_capability.calls[1]["gateway_name"] == napcat_gateway_name
|
||||||
|
assert gateway_capability.calls[1]["ready"] is False
|
||||||
|
assert gateway_capability.calls[1]["platform"] == "qq"
|
||||||
@@ -441,8 +441,8 @@ class TestSDK:
|
|||||||
def set_plugin_config(self, config):
|
def set_plugin_config(self, config):
|
||||||
self.configs.append(config)
|
self.configs.append(config)
|
||||||
|
|
||||||
async def on_config_update(self, config, version):
|
async def on_config_update(self, scope, config, version):
|
||||||
self.updates.append((config, version, list(self.configs)))
|
self.updates.append((scope, config, version, list(self.configs)))
|
||||||
|
|
||||||
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
|
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
|
||||||
plugin = DummyPlugin()
|
plugin = DummyPlugin()
|
||||||
@@ -453,14 +453,60 @@ class TestSDK:
|
|||||||
message_type=MessageType.REQUEST,
|
message_type=MessageType.REQUEST,
|
||||||
method="plugin.config_updated",
|
method="plugin.config_updated",
|
||||||
plugin_id="demo_plugin",
|
plugin_id="demo_plugin",
|
||||||
payload={"config_data": {"enabled": True}, "config_version": "v2"},
|
payload={
|
||||||
|
"plugin_id": "demo_plugin",
|
||||||
|
"config_scope": "self",
|
||||||
|
"config_data": {"enabled": True},
|
||||||
|
"config_version": "v2",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await runner._handle_config_updated(envelope)
|
response = await runner._handle_config_updated(envelope)
|
||||||
|
|
||||||
assert response.payload["acknowledged"] is True
|
assert response.payload["acknowledged"] is True
|
||||||
assert plugin.configs == [{"enabled": True}]
|
assert plugin.configs == [{"enabled": True}]
|
||||||
assert plugin.updates == [({"enabled": True}, "v2", [{"enabled": True}])]
|
assert plugin.updates == [("self", {"enabled": True}, "v2", [{"enabled": True}])]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_global_config_update_does_not_override_plugin_config(self):
|
||||||
|
"""bot/model 广播不应覆盖插件自身配置缓存。"""
|
||||||
|
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||||
|
from src.plugin_runtime.runner.runner_main import PluginRunner
|
||||||
|
|
||||||
|
class DummyPlugin:
|
||||||
|
def __init__(self):
|
||||||
|
self.configs = []
|
||||||
|
self.updates = []
|
||||||
|
|
||||||
|
def set_plugin_config(self, config):
|
||||||
|
self.configs.append(config)
|
||||||
|
|
||||||
|
async def on_config_update(self, scope, config, version):
|
||||||
|
self.updates.append((scope, config, version, list(self.configs)))
|
||||||
|
|
||||||
|
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
|
||||||
|
plugin = DummyPlugin()
|
||||||
|
runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin)
|
||||||
|
plugin.set_plugin_config({"plugin_enabled": True})
|
||||||
|
|
||||||
|
envelope = Envelope(
|
||||||
|
request_id=1,
|
||||||
|
message_type=MessageType.REQUEST,
|
||||||
|
method="plugin.config_updated",
|
||||||
|
plugin_id="demo_plugin",
|
||||||
|
payload={
|
||||||
|
"plugin_id": "demo_plugin",
|
||||||
|
"config_scope": "model",
|
||||||
|
"config_data": {"models": []},
|
||||||
|
"config_version": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await runner._handle_config_updated(envelope)
|
||||||
|
|
||||||
|
assert response.payload["acknowledged"] is True
|
||||||
|
assert plugin.configs == [{"plugin_enabled": True}]
|
||||||
|
assert plugin.updates == [("model", {"models": []}, "", [{"plugin_enabled": True}])]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch):
|
async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch):
|
||||||
@@ -911,6 +957,120 @@ class TestDependencyResolution:
|
|||||||
assert loader.failed_plugins == {}
|
assert loader.failed_plugins == {}
|
||||||
assert loaded[0].instance.answer() == 42
|
assert loaded[0].instance.answer() == 42
|
||||||
|
|
||||||
|
def test_loader_requires_sdk_plugin_to_override_on_config_update(self, tmp_path):
|
||||||
|
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||||
|
|
||||||
|
plugin_root = tmp_path / "plugins"
|
||||||
|
plugin_root.mkdir()
|
||||||
|
plugin_dir = plugin_root / "demo_plugin"
|
||||||
|
plugin_dir.mkdir()
|
||||||
|
|
||||||
|
(plugin_dir / "_manifest.json").write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"name": "demo_plugin",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "demo",
|
||||||
|
"author": "tester",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
(plugin_dir / "plugin.py").write_text(
|
||||||
|
"from maibot_sdk import MaiBotPlugin\n\n"
|
||||||
|
"class DemoPlugin(MaiBotPlugin):\n"
|
||||||
|
" async def on_load(self):\n"
|
||||||
|
" pass\n\n"
|
||||||
|
" async def on_unload(self):\n"
|
||||||
|
" pass\n\n"
|
||||||
|
"def create_plugin():\n"
|
||||||
|
" return DemoPlugin()\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
loader = PluginLoader()
|
||||||
|
loaded = loader.discover_and_load([str(plugin_root)])
|
||||||
|
|
||||||
|
assert loaded == []
|
||||||
|
assert "demo_plugin" in loader.failed_plugins
|
||||||
|
assert "on_config_update" in loader.failed_plugins["demo_plugin"]
|
||||||
|
|
||||||
|
def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path):
|
||||||
|
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||||
|
|
||||||
|
plugin_root = tmp_path / "plugins"
|
||||||
|
plugin_root.mkdir()
|
||||||
|
plugin_dir = plugin_root / "demo_plugin"
|
||||||
|
plugin_dir.mkdir()
|
||||||
|
|
||||||
|
(plugin_dir / "_manifest.json").write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"name": "demo_plugin",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "demo",
|
||||||
|
"author": "tester",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
(plugin_dir / "plugin.py").write_text(
|
||||||
|
"from maibot_sdk import MaiBotPlugin\n\n"
|
||||||
|
"class DemoPlugin(MaiBotPlugin):\n"
|
||||||
|
" async def on_unload(self):\n"
|
||||||
|
" pass\n\n"
|
||||||
|
" async def on_config_update(self, scope, config_data, version):\n"
|
||||||
|
" pass\n\n"
|
||||||
|
"def create_plugin():\n"
|
||||||
|
" return DemoPlugin()\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
loader = PluginLoader()
|
||||||
|
loaded = loader.discover_and_load([str(plugin_root)])
|
||||||
|
|
||||||
|
assert loaded == []
|
||||||
|
assert "demo_plugin" in loader.failed_plugins
|
||||||
|
assert "on_load" in loader.failed_plugins["demo_plugin"]
|
||||||
|
|
||||||
|
def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path):
|
||||||
|
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||||
|
|
||||||
|
plugin_root = tmp_path / "plugins"
|
||||||
|
plugin_root.mkdir()
|
||||||
|
plugin_dir = plugin_root / "demo_plugin"
|
||||||
|
plugin_dir.mkdir()
|
||||||
|
|
||||||
|
(plugin_dir / "_manifest.json").write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"name": "demo_plugin",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "demo",
|
||||||
|
"author": "tester",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
(plugin_dir / "plugin.py").write_text(
|
||||||
|
"from maibot_sdk import MaiBotPlugin\n\n"
|
||||||
|
"class DemoPlugin(MaiBotPlugin):\n"
|
||||||
|
" async def on_load(self):\n"
|
||||||
|
" pass\n\n"
|
||||||
|
" async def on_config_update(self, scope, config_data, version):\n"
|
||||||
|
" pass\n\n"
|
||||||
|
"def create_plugin():\n"
|
||||||
|
" return DemoPlugin()\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
loader = PluginLoader()
|
||||||
|
loaded = loader.discover_and_load([str(plugin_root)])
|
||||||
|
|
||||||
|
assert loaded == []
|
||||||
|
assert "demo_plugin" in loader.failed_plugins
|
||||||
|
assert "on_unload" in loader.failed_plugins["demo_plugin"]
|
||||||
|
|
||||||
def test_isolate_sys_path_preserves_plugin_dirs(self):
|
def test_isolate_sys_path_preserves_plugin_dirs(self):
|
||||||
from src.plugin_runtime.runner import runner_main
|
from src.plugin_runtime.runner import runner_main
|
||||||
|
|
||||||
@@ -2299,9 +2459,10 @@ class TestIntegration:
|
|||||||
assert refresh_calls == [True]
|
assert refresh_calls == [True]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path):
|
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path):
|
||||||
from src.plugin_runtime import integration as integration_module
|
from src.plugin_runtime import integration as integration_module
|
||||||
from src.config.file_watcher import FileChange
|
from src.config.file_watcher import FileChange
|
||||||
|
import json
|
||||||
|
|
||||||
builtin_root = tmp_path / "src" / "plugins" / "built_in"
|
builtin_root = tmp_path / "src" / "plugins" / "built_in"
|
||||||
thirdparty_root = tmp_path / "plugins"
|
thirdparty_root = tmp_path / "plugins"
|
||||||
@@ -2311,6 +2472,10 @@ class TestIntegration:
|
|||||||
beta_dir.mkdir(parents=True)
|
beta_dir.mkdir(parents=True)
|
||||||
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
|
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
|
||||||
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
|
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
|
||||||
|
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||||
|
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
|
||||||
|
(alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
|
||||||
|
(beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
|
||||||
|
|
||||||
monkeypatch.chdir(tmp_path)
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
@@ -2318,31 +2483,95 @@ class TestIntegration:
|
|||||||
def __init__(self, plugin_dirs, plugins):
|
def __init__(self, plugin_dirs, plugins):
|
||||||
self._plugin_dirs = plugin_dirs
|
self._plugin_dirs = plugin_dirs
|
||||||
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
|
self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
|
||||||
self.reload_calls = []
|
self.config_updates = []
|
||||||
|
|
||||||
async def reload_plugin(self, plugin_id, reason="manual"):
|
async def notify_plugin_config_updated(
|
||||||
self.reload_calls.append((plugin_id, reason))
|
self,
|
||||||
|
plugin_id,
|
||||||
|
config_data,
|
||||||
|
config_version="",
|
||||||
|
config_scope="self",
|
||||||
|
):
|
||||||
|
self.config_updates.append((plugin_id, config_data, config_version, config_scope))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
manager = integration_module.PluginRuntimeManager()
|
manager = integration_module.PluginRuntimeManager()
|
||||||
manager._started = True
|
manager._started = True
|
||||||
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
|
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
|
||||||
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
|
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
|
||||||
refresh_calls = []
|
|
||||||
|
|
||||||
def fake_refresh() -> None:
|
|
||||||
refresh_calls.append(True)
|
|
||||||
|
|
||||||
manager._refresh_plugin_config_watch_subscriptions = fake_refresh
|
|
||||||
|
|
||||||
await manager._handle_plugin_config_changes(
|
await manager._handle_plugin_config_changes(
|
||||||
"alpha",
|
"alpha",
|
||||||
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
|
[FileChange(change_type=1, path=alpha_dir / "config.toml")],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")]
|
assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "", "self")]
|
||||||
assert manager._third_party_supervisor.reload_calls == []
|
assert manager._third_party_supervisor.config_updates == []
|
||||||
assert refresh_calls == [True]
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch):
|
||||||
|
from src.plugin_runtime import integration as integration_module
|
||||||
|
|
||||||
|
class FakeRegistration:
|
||||||
|
def __init__(self, subscriptions):
|
||||||
|
self.config_reload_subscriptions = subscriptions
|
||||||
|
|
||||||
|
class FakeSupervisor:
|
||||||
|
def __init__(self, registrations):
|
||||||
|
self._registered_plugins = registrations
|
||||||
|
self.config_updates = []
|
||||||
|
|
||||||
|
def get_config_reload_subscribers(self, scope):
|
||||||
|
matched_plugins = []
|
||||||
|
for plugin_id, registration in self._registered_plugins.items():
|
||||||
|
if scope in registration.config_reload_subscriptions:
|
||||||
|
matched_plugins.append(plugin_id)
|
||||||
|
return matched_plugins
|
||||||
|
|
||||||
|
async def notify_plugin_config_updated(
|
||||||
|
self,
|
||||||
|
plugin_id,
|
||||||
|
config_data,
|
||||||
|
config_version="",
|
||||||
|
config_scope="self",
|
||||||
|
):
|
||||||
|
self.config_updates.append((plugin_id, config_data, config_version, config_scope))
|
||||||
|
return True
|
||||||
|
|
||||||
|
fake_global = SimpleNamespace(plugin_runtime=SimpleNamespace(enabled=True))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
integration_module.config_manager,
|
||||||
|
"get_global_config",
|
||||||
|
lambda: SimpleNamespace(model_dump=lambda: {"bot": {"name": "MaiBot"}}, plugin_runtime=fake_global.plugin_runtime),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
integration_module.config_manager,
|
||||||
|
"get_model_config",
|
||||||
|
lambda: SimpleNamespace(model_dump=lambda: {"models": [{"name": "demo"}]}),
|
||||||
|
)
|
||||||
|
|
||||||
|
manager = integration_module.PluginRuntimeManager()
|
||||||
|
manager._started = True
|
||||||
|
manager._builtin_supervisor = FakeSupervisor(
|
||||||
|
{
|
||||||
|
"alpha": FakeRegistration(["bot"]),
|
||||||
|
"beta": FakeRegistration([]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
manager._third_party_supervisor = FakeSupervisor(
|
||||||
|
{
|
||||||
|
"gamma": FakeRegistration(["model"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await manager._handle_main_config_reload(["bot", "model"])
|
||||||
|
|
||||||
|
assert manager._builtin_supervisor.config_updates == [
|
||||||
|
("alpha", {"bot": {"name": "MaiBot"}}, "", "bot")
|
||||||
|
]
|
||||||
|
assert manager._third_party_supervisor.config_updates == [
|
||||||
|
("gamma", {"models": [{"name": "demo"}]}, "", "model")
|
||||||
|
]
|
||||||
|
|
||||||
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
|
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
|
||||||
from src.plugin_runtime import integration as integration_module
|
from src.plugin_runtime import integration as integration_module
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import tomlkit
|
import tomlkit
|
||||||
@@ -61,6 +62,7 @@ MODEL_CONFIG_VERSION: str = "1.12.0"
|
|||||||
logger = get_logger("config")
|
logger = get_logger("config")
|
||||||
|
|
||||||
T = TypeVar("T", bound="ConfigBase")
|
T = TypeVar("T", bound="ConfigBase")
|
||||||
|
ConfigReloadCallback = Callable[[Sequence[str]], object] | Callable[[], object]
|
||||||
|
|
||||||
|
|
||||||
class Config(ConfigBase):
|
class Config(ConfigBase):
|
||||||
@@ -190,7 +192,7 @@ class ConfigManager:
|
|||||||
self.global_config: Config | None = None
|
self.global_config: Config | None = None
|
||||||
self.model_config: ModelConfig | None = None
|
self.model_config: ModelConfig | None = None
|
||||||
self._reload_lock: asyncio.Lock = asyncio.Lock()
|
self._reload_lock: asyncio.Lock = asyncio.Lock()
|
||||||
self._reload_callbacks: list[Callable[[], object]] = []
|
self._reload_callbacks: list[ConfigReloadCallback] = []
|
||||||
self._file_watcher: FileWatcher | None = None
|
self._file_watcher: FileWatcher | None = None
|
||||||
self._file_watcher_subscription_id: str | None = None
|
self._file_watcher_subscription_id: str | None = None
|
||||||
self._hot_reload_min_interval_s: float = 1.0
|
self._hot_reload_min_interval_s: float = 1.0
|
||||||
@@ -226,16 +228,125 @@ class ConfigManager:
|
|||||||
raise RuntimeError(t("config.model_not_initialized"))
|
raise RuntimeError(t("config.model_not_initialized"))
|
||||||
return self.model_config
|
return self.model_config
|
||||||
|
|
||||||
def register_reload_callback(self, callback: Callable[[], object]) -> None:
|
def register_reload_callback(self, callback: ConfigReloadCallback) -> None:
|
||||||
|
"""注册配置热重载回调。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: 配置热重载回调。允许无参回调,也允许接收
|
||||||
|
``Sequence[str]`` 类型的变更范围列表。
|
||||||
|
"""
|
||||||
|
|
||||||
self._reload_callbacks.append(callback)
|
self._reload_callbacks.append(callback)
|
||||||
|
|
||||||
def unregister_reload_callback(self, callback: Callable[[], object]) -> None:
|
def unregister_reload_callback(self, callback: ConfigReloadCallback) -> None:
|
||||||
|
"""注销配置热重载回调。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: 先前注册过的回调对象。
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._reload_callbacks.remove(callback)
|
self._reload_callbacks.remove(callback)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def reload_config(self) -> bool:
|
@staticmethod
|
||||||
|
def _normalize_changed_scopes(changed_scopes: Sequence[str] | None) -> tuple[str, ...]:
|
||||||
|
"""规范化配置变更范围列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
changed_scopes: 原始配置变更范围。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, ...]: 去重后的配置变更范围元组。
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not changed_scopes:
|
||||||
|
return ("bot", "model")
|
||||||
|
|
||||||
|
normalized_scopes: list[str] = []
|
||||||
|
for scope in changed_scopes:
|
||||||
|
normalized_scope = str(scope or "").strip().lower()
|
||||||
|
if normalized_scope not in {"bot", "model"}:
|
||||||
|
continue
|
||||||
|
if normalized_scope not in normalized_scopes:
|
||||||
|
normalized_scopes.append(normalized_scope)
|
||||||
|
return tuple(normalized_scopes)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_changed_scopes(changes: Sequence[FileChange]) -> tuple[str, ...]:
|
||||||
|
"""根据文件变更列表推断配置变更范围。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
changes: 文件监听器返回的变更列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, ...]: 命中的配置变更范围元组。
|
||||||
|
"""
|
||||||
|
|
||||||
|
changed_scopes: list[str] = []
|
||||||
|
for change in changes:
|
||||||
|
file_name = change.path.name
|
||||||
|
if file_name == "bot_config.toml" and "bot" not in changed_scopes:
|
||||||
|
changed_scopes.append("bot")
|
||||||
|
if file_name == "model_config.toml" and "model" not in changed_scopes:
|
||||||
|
changed_scopes.append("model")
|
||||||
|
return tuple(changed_scopes)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _callback_accepts_scopes(callback: ConfigReloadCallback) -> bool:
|
||||||
|
"""判断回调是否接收配置变更范围参数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: 待检测的回调对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若回调可接收一个位置参数或可变位置参数,则返回 ``True``。
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
parameters = inspect.signature(callback).parameters.values()
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
positional_params = {
|
||||||
|
inspect.Parameter.POSITIONAL_ONLY,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
}
|
||||||
|
for parameter in parameters:
|
||||||
|
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||||
|
return True
|
||||||
|
if parameter.kind in positional_params:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _invoke_reload_callback(
|
||||||
|
self,
|
||||||
|
callback: ConfigReloadCallback,
|
||||||
|
changed_scopes: Sequence[str],
|
||||||
|
) -> None:
|
||||||
|
"""执行单个配置热重载回调。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: 要执行的回调对象。
|
||||||
|
changed_scopes: 本次热重载命中的配置范围。
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = callback(changed_scopes) if self._callback_accepts_scopes(callback) else callback()
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
await result
|
||||||
|
|
||||||
|
async def reload_config(self, changed_scopes: Sequence[str] | None = None) -> bool:
|
||||||
|
"""重新加载主配置和模型配置。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
changed_scopes: 本次触发热重载的配置范围。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否重载成功。
|
||||||
|
"""
|
||||||
|
|
||||||
|
normalized_scopes = self._normalize_changed_scopes(changed_scopes)
|
||||||
async with self._reload_lock:
|
async with self._reload_lock:
|
||||||
try:
|
try:
|
||||||
global_config_new, global_updated = load_config_from_file(
|
global_config_new, global_updated = load_config_from_file(
|
||||||
@@ -265,9 +376,7 @@ class ConfigManager:
|
|||||||
|
|
||||||
for callback in list(self._reload_callbacks):
|
for callback in list(self._reload_callbacks):
|
||||||
try:
|
try:
|
||||||
result = callback()
|
await self._invoke_reload_callback(callback, normalized_scopes)
|
||||||
if asyncio.iscoroutine(result):
|
|
||||||
await result
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(t("config.reload_callback_failed", error=exc))
|
logger.warning(t("config.reload_callback_failed", error=exc))
|
||||||
return True
|
return True
|
||||||
@@ -312,6 +421,12 @@ class ConfigManager:
|
|||||||
self._file_watcher = None
|
self._file_watcher = None
|
||||||
|
|
||||||
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
|
async def _handle_file_changes(self, changes: Sequence[FileChange]) -> None:
|
||||||
|
"""处理主配置与模型配置文件变更。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
changes: 当前批次收集到的文件变更列表。
|
||||||
|
"""
|
||||||
|
|
||||||
if not changes:
|
if not changes:
|
||||||
return
|
return
|
||||||
now_monotonic = asyncio.get_running_loop().time()
|
now_monotonic = asyncio.get_running_loop().time()
|
||||||
@@ -321,7 +436,11 @@ class ConfigManager:
|
|||||||
self._last_hot_reload_monotonic = now_monotonic
|
self._last_hot_reload_monotonic = now_monotonic
|
||||||
logger.info(t("config.file_change_detected"))
|
logger.info(t("config.file_change_detected"))
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.reload_config(), timeout=self._hot_reload_timeout_s)
|
changed_scopes = self._resolve_changed_scopes(changes)
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self.reload_config(changed_scopes=changed_scopes),
|
||||||
|
timeout=self._hot_reload_timeout_s,
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s))
|
logger.error(t("config.reload_timeout", timeout_seconds=self._hot_reload_timeout_s))
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,15 @@
|
|||||||
|
|
||||||
功能:
|
功能:
|
||||||
1. 定期随机选取指定数量的表达方式
|
1. 定期随机选取指定数量的表达方式
|
||||||
2. 使用LLM进行评估
|
2. 使用 LLM 进行评估
|
||||||
3. 通过评估的:rejected=0, checked=1
|
3. 通过评估的:rejected=0, checked=1
|
||||||
4. 未通过评估的:rejected=1, checked=1
|
4. 未通过评估的:rejected=1, checked=1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
@@ -146,7 +146,8 @@ class ExpressionAutoCheckTask(AsyncTask):
|
|||||||
选中的表达方式列表
|
选中的表达方式列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
# 这里只做查询,避免退出上下文时自动提交导致 ORM 实例过期。
|
||||||
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(Expression)
|
statement = select(Expression)
|
||||||
all_expressions = session.exec(statement).all()
|
all_expressions = session.exec(statement).all()
|
||||||
|
|
||||||
|
|||||||
@@ -399,6 +399,7 @@ class PluginRunnerSupervisor:
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
config_data: Optional[Dict[str, Any]] = None,
|
config_data: Optional[Dict[str, Any]] = None,
|
||||||
config_version: str = "",
|
config_version: str = "",
|
||||||
|
config_scope: str = "self",
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向 Runner 推送插件配置更新。
|
"""向 Runner 推送插件配置更新。
|
||||||
|
|
||||||
@@ -406,12 +407,14 @@ class PluginRunnerSupervisor:
|
|||||||
plugin_id: 目标插件 ID。
|
plugin_id: 目标插件 ID。
|
||||||
config_data: 配置内容。
|
config_data: 配置内容。
|
||||||
config_version: 配置版本号。
|
config_version: 配置版本号。
|
||||||
|
config_scope: 配置变更范围。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 请求是否成功送达并被 Runner 接受。
|
bool: 请求是否成功送达并被 Runner 接受。
|
||||||
"""
|
"""
|
||||||
payload = ConfigUpdatedPayload(
|
payload = ConfigUpdatedPayload(
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
|
config_scope=config_scope,
|
||||||
config_version=config_version,
|
config_version=config_version,
|
||||||
config_data=config_data or {},
|
config_data=config_data or {},
|
||||||
)
|
)
|
||||||
@@ -428,6 +431,22 @@ class PluginRunnerSupervisor:
|
|||||||
|
|
||||||
return bool(response.payload.get("acknowledged", False))
|
return bool(response.payload.get("acknowledged", False))
|
||||||
|
|
||||||
|
def get_config_reload_subscribers(self, scope: str) -> List[str]:
|
||||||
|
"""返回订阅指定全局配置广播的插件列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 已声明订阅该范围的插件 ID 列表。
|
||||||
|
"""
|
||||||
|
|
||||||
|
matched_plugins: List[str] = []
|
||||||
|
for plugin_id, registration in self._registered_plugins.items():
|
||||||
|
if scope in registration.config_reload_subscriptions:
|
||||||
|
matched_plugins.append(plugin_id)
|
||||||
|
return matched_plugins
|
||||||
|
|
||||||
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
|
async def _wait_for_runner_connection(self, timeout_sec: float) -> None:
|
||||||
"""等待 Runner 建立 RPC 连接。
|
"""等待 Runner 建立 RPC 连接。
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import json
|
|||||||
import tomlkit
|
import tomlkit
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import config_manager
|
||||||
from src.config.file_watcher import FileChange, FileWatcher
|
from src.config.file_watcher import FileChange, FileWatcher
|
||||||
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
|
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
|
||||||
from src.plugin_runtime.capabilities import (
|
from src.plugin_runtime.capabilities import (
|
||||||
@@ -69,6 +69,8 @@ class PluginRuntimeManager(
|
|||||||
self._plugin_source_watcher_subscription_id: Optional[str] = None
|
self._plugin_source_watcher_subscription_id: Optional[str] = None
|
||||||
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
|
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
|
||||||
self._plugin_path_cache: Dict[str, Path] = {}
|
self._plugin_path_cache: Dict[str, Path] = {}
|
||||||
|
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
|
||||||
|
self._config_reload_callback_registered: bool = False
|
||||||
|
|
||||||
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
||||||
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
||||||
@@ -108,7 +110,7 @@ class PluginRuntimeManager(
|
|||||||
logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动")
|
logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动")
|
||||||
return
|
return
|
||||||
|
|
||||||
_cfg = global_config.plugin_runtime
|
_cfg = config_manager.get_global_config().plugin_runtime
|
||||||
if not _cfg.enabled:
|
if not _cfg.enabled:
|
||||||
logger.info("插件运行时已在配置中禁用,跳过启动")
|
logger.info("插件运行时已在配置中禁用,跳过启动")
|
||||||
return
|
return
|
||||||
@@ -166,11 +168,16 @@ class PluginRuntimeManager(
|
|||||||
await self._third_party_supervisor.start()
|
await self._third_party_supervisor.start()
|
||||||
started_supervisors.append(self._third_party_supervisor)
|
started_supervisors.append(self._third_party_supervisor)
|
||||||
await self._start_plugin_file_watcher()
|
await self._start_plugin_file_watcher()
|
||||||
|
config_manager.register_reload_callback(self._config_reload_callback)
|
||||||
|
self._config_reload_callback_registered = True
|
||||||
self._started = True
|
self._started = True
|
||||||
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {third_party_dirs or '无'}")
|
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or '无'}, 第三方: {third_party_dirs or '无'}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
|
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
|
||||||
await self._stop_plugin_file_watcher()
|
await self._stop_plugin_file_watcher()
|
||||||
|
if self._config_reload_callback_registered:
|
||||||
|
config_manager.unregister_reload_callback(self._config_reload_callback)
|
||||||
|
self._config_reload_callback_registered = False
|
||||||
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
|
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
|
||||||
platform_io_manager.clear_inbound_dispatcher()
|
platform_io_manager.clear_inbound_dispatcher()
|
||||||
try:
|
try:
|
||||||
@@ -188,6 +195,9 @@ class PluginRuntimeManager(
|
|||||||
|
|
||||||
platform_io_manager = get_platform_io_manager()
|
platform_io_manager = get_platform_io_manager()
|
||||||
await self._stop_plugin_file_watcher()
|
await self._stop_plugin_file_watcher()
|
||||||
|
if self._config_reload_callback_registered:
|
||||||
|
config_manager.unregister_reload_callback(self._config_reload_callback)
|
||||||
|
self._config_reload_callback_registered = False
|
||||||
|
|
||||||
coroutines: List[Coroutine[Any, Any, None]] = []
|
coroutines: List[Coroutine[Any, Any, None]] = []
|
||||||
if self._builtin_supervisor:
|
if self._builtin_supervisor:
|
||||||
@@ -233,6 +243,7 @@ class PluginRuntimeManager(
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
config_data: Optional[Dict[str, Any]] = None,
|
config_data: Optional[Dict[str, Any]] = None,
|
||||||
config_version: str = "",
|
config_version: str = "",
|
||||||
|
config_scope: str = "self",
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向拥有该插件的 Supervisor 推送配置更新事件。
|
"""向拥有该插件的 Supervisor 推送配置更新事件。
|
||||||
|
|
||||||
@@ -240,6 +251,7 @@ class PluginRuntimeManager(
|
|||||||
plugin_id: 插件 ID
|
plugin_id: 插件 ID
|
||||||
config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载)
|
config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载)
|
||||||
config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制
|
config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制
|
||||||
|
config_scope: 配置变更范围。
|
||||||
"""
|
"""
|
||||||
if not self._started:
|
if not self._started:
|
||||||
return False
|
return False
|
||||||
@@ -258,12 +270,67 @@ class PluginRuntimeManager(
|
|||||||
if config_data is not None
|
if config_data is not None
|
||||||
else self._load_plugin_config_for_supervisor(sv, plugin_id)
|
else self._load_plugin_config_for_supervisor(sv, plugin_id)
|
||||||
)
|
)
|
||||||
await sv.notify_plugin_config_updated(
|
return await sv.notify_plugin_config_updated(
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
config_data=config_payload,
|
config_data=config_payload,
|
||||||
config_version=config_version,
|
config_version=config_version,
|
||||||
|
config_scope=config_scope,
|
||||||
)
|
)
|
||||||
return True
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
|
||||||
|
"""规范化配置热重载范围列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
changed_scopes: 原始配置热重载范围列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, ...]: 去重后的有效配置范围元组。
|
||||||
|
"""
|
||||||
|
|
||||||
|
normalized_scopes: list[str] = []
|
||||||
|
for scope in changed_scopes:
|
||||||
|
normalized_scope = str(scope or "").strip().lower()
|
||||||
|
if normalized_scope not in {"bot", "model"}:
|
||||||
|
continue
|
||||||
|
if normalized_scope not in normalized_scopes:
|
||||||
|
normalized_scopes.append(normalized_scope)
|
||||||
|
return tuple(normalized_scopes)
|
||||||
|
|
||||||
|
async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None:
|
||||||
|
"""向订阅指定范围的插件广播配置热重载。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
|
||||||
|
config_data: 最新配置数据。
|
||||||
|
"""
|
||||||
|
|
||||||
|
for supervisor in self.supervisors:
|
||||||
|
for plugin_id in supervisor.get_config_reload_subscribers(scope):
|
||||||
|
delivered = await supervisor.notify_plugin_config_updated(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
config_data=config_data,
|
||||||
|
config_version="",
|
||||||
|
config_scope=scope,
|
||||||
|
)
|
||||||
|
if not delivered:
|
||||||
|
logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败")
|
||||||
|
|
||||||
|
async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None:
|
||||||
|
"""处理 bot/model 主配置热重载广播。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
changed_scopes: 本次热重载命中的配置范围列表。
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self._started:
|
||||||
|
return
|
||||||
|
|
||||||
|
normalized_scopes = self._normalize_config_reload_scopes(changed_scopes)
|
||||||
|
if "bot" in normalized_scopes:
|
||||||
|
await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump())
|
||||||
|
if "model" in normalized_scopes:
|
||||||
|
await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump())
|
||||||
|
|
||||||
# ─── 事件桥接 ──────────────────────────────────────────────
|
# ─── 事件桥接 ──────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -612,16 +679,12 @@ class PluginRuntimeManager(
|
|||||||
return None if plugin_path is None else plugin_path / "config.toml"
|
return None if plugin_path is None else plugin_path / "config.toml"
|
||||||
|
|
||||||
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
|
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
|
||||||
"""处理单个插件配置文件变化,并精确重载目标插件。
|
"""处理单个插件配置文件变化,并定向派发自配置热更新。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
plugin_id: 发生配置变更的插件 ID。
|
plugin_id: 发生配置变更的插件 ID。
|
||||||
changes: 当前批次收集到的配置文件变更列表。
|
changes: 当前批次收集到的配置文件变更列表。
|
||||||
|
|
||||||
Notes:
|
|
||||||
这里选择“精确重载该插件”,而不是仅推送软性的配置更新通知。
|
|
||||||
这样可以保证没有实现 ``on_config_update()`` 的插件也能重新执行
|
|
||||||
``on_load()``,让磁盘上的 ``config.toml`` 修改对插件运行态真正生效。
|
|
||||||
"""
|
"""
|
||||||
if not self._started or not changes:
|
if not self._started or not changes:
|
||||||
return
|
return
|
||||||
@@ -636,15 +699,15 @@ class PluginRuntimeManager(
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._load_plugin_config_for_supervisor(supervisor, plugin_id)
|
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
|
||||||
reload_success = await supervisor.reload_plugin(
|
delivered = await supervisor.notify_plugin_config_updated(
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
reason="config_file_changed",
|
config_data=config_payload,
|
||||||
|
config_version="",
|
||||||
|
config_scope="self",
|
||||||
)
|
)
|
||||||
if reload_success:
|
if not delivered:
|
||||||
self._refresh_plugin_config_watch_subscriptions()
|
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
|
||||||
else:
|
|
||||||
logger.warning(f"插件 {plugin_id} 配置文件变更后重载失败")
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
|
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
|
||||||
|
|
||||||
@@ -652,8 +715,8 @@ class PluginRuntimeManager(
|
|||||||
"""处理插件源码相关变化。
|
"""处理插件源码相关变化。
|
||||||
|
|
||||||
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
|
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
|
||||||
单独的 per-plugin watcher 处理,并精确重载对应插件,避免放大成
|
单独的 per-plugin watcher 处理,并定向派发给目标插件的
|
||||||
不必要的跨插件 reload。
|
``on_config_update()``,避免放大成不必要的跨插件 reload。
|
||||||
"""
|
"""
|
||||||
if not self._started or not changes:
|
if not self._started or not changes:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -29,6 +29,14 @@ class MessageType(str, Enum):
|
|||||||
BROADCAST = "broadcast"
|
BROADCAST = "broadcast"
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigReloadScope(str, Enum):
|
||||||
|
"""配置热重载范围。"""
|
||||||
|
|
||||||
|
SELF = "self"
|
||||||
|
BOT = "bot"
|
||||||
|
MODEL = "model"
|
||||||
|
|
||||||
|
|
||||||
# ====== 请求 ID 生成器 ======
|
# ====== 请求 ID 生成器 ======
|
||||||
class RequestIdGenerator:
|
class RequestIdGenerator:
|
||||||
"""单调递增 int64 请求 ID 生成器"""
|
"""单调递增 int64 请求 ID 生成器"""
|
||||||
@@ -158,6 +166,8 @@ class RegisterPluginPayload(BaseModel):
|
|||||||
"""组件列表"""
|
"""组件列表"""
|
||||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||||
"""所需能力列表"""
|
"""所需能力列表"""
|
||||||
|
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
|
||||||
|
"""订阅的全局配置热重载范围"""
|
||||||
|
|
||||||
|
|
||||||
class BootstrapPluginPayload(BaseModel):
|
class BootstrapPluginPayload(BaseModel):
|
||||||
@@ -236,6 +246,8 @@ class ConfigUpdatedPayload(BaseModel):
|
|||||||
|
|
||||||
plugin_id: str = Field(description="插件 ID")
|
plugin_id: str = Field(description="插件 ID")
|
||||||
"""插件 ID"""
|
"""插件 ID"""
|
||||||
|
config_scope: ConfigReloadScope = Field(description="配置变更范围")
|
||||||
|
"""配置变更范围"""
|
||||||
config_version: str = Field(description="新配置版本")
|
config_version: str = Field(description="新配置版本")
|
||||||
"""新配置版本"""
|
"""新配置版本"""
|
||||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
||||||
|
|||||||
@@ -403,6 +403,7 @@ class PluginLoader:
|
|||||||
create_plugin = getattr(module, "create_plugin", None)
|
create_plugin = getattr(module, "create_plugin", None)
|
||||||
if create_plugin is not None:
|
if create_plugin is not None:
|
||||||
instance = create_plugin()
|
instance = create_plugin()
|
||||||
|
self._validate_sdk_plugin_contract(plugin_id, instance)
|
||||||
logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
|
logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
|
||||||
return PluginMeta(
|
return PluginMeta(
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
@@ -432,6 +433,35 @@ class PluginLoader:
|
|||||||
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin")
|
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_sdk_plugin_contract(plugin_id: str, instance: Any) -> None:
|
||||||
|
"""校验 SDK 插件的基础契约。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 当前插件 ID。
|
||||||
|
instance: ``create_plugin()`` 返回的插件实例。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: 当插件未覆盖必需生命周期方法或订阅声明不合法时抛出。
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from maibot_sdk.plugin import MaiBotPlugin
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(instance, MaiBotPlugin):
|
||||||
|
return
|
||||||
|
|
||||||
|
if type(instance).on_load is MaiBotPlugin.on_load:
|
||||||
|
raise TypeError(f"插件 {plugin_id} 必须实现 on_load()")
|
||||||
|
if type(instance).on_unload is MaiBotPlugin.on_unload:
|
||||||
|
raise TypeError(f"插件 {plugin_id} 必须实现 on_unload()")
|
||||||
|
if type(instance).on_config_update is MaiBotPlugin.on_config_update:
|
||||||
|
raise TypeError(f"插件 {plugin_id} 必须实现 on_config_update()")
|
||||||
|
|
||||||
|
instance.get_config_reload_subscriptions()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _temporary_sys_path_entry(path: Path) -> Iterator[None]:
|
def _temporary_sys_path_entry(path: Path) -> Iterator[None]:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIR
|
|||||||
from src.plugin_runtime.protocol.envelope import (
|
from src.plugin_runtime.protocol.envelope import (
|
||||||
BootstrapPluginPayload,
|
BootstrapPluginPayload,
|
||||||
ComponentDeclaration,
|
ComponentDeclaration,
|
||||||
|
ConfigUpdatedPayload,
|
||||||
Envelope,
|
Envelope,
|
||||||
HealthPayload,
|
HealthPayload,
|
||||||
InvokePayload,
|
InvokePayload,
|
||||||
@@ -342,6 +343,7 @@ class PluginRunner:
|
|||||||
"""
|
"""
|
||||||
# 收集插件组件声明
|
# 收集插件组件声明
|
||||||
components: List[ComponentDeclaration] = []
|
components: List[ComponentDeclaration] = []
|
||||||
|
config_reload_subscriptions: List[str] = []
|
||||||
instance = meta.instance
|
instance = meta.instance
|
||||||
|
|
||||||
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
|
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
|
||||||
@@ -355,12 +357,15 @@ class PluginRunner:
|
|||||||
)
|
)
|
||||||
for comp_info in instance.get_components()
|
for comp_info in instance.get_components()
|
||||||
)
|
)
|
||||||
|
if hasattr(instance, "get_config_reload_subscriptions"):
|
||||||
|
config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
|
||||||
|
|
||||||
reg_payload = RegisterPluginPayload(
|
reg_payload = RegisterPluginPayload(
|
||||||
plugin_id=meta.plugin_id,
|
plugin_id=meta.plugin_id,
|
||||||
plugin_version=meta.version,
|
plugin_version=meta.version,
|
||||||
components=components,
|
components=components,
|
||||||
capabilities_required=meta.capabilities_required,
|
capabilities_required=meta.capabilities_required,
|
||||||
|
config_reload_subscriptions=config_reload_subscriptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -911,16 +916,26 @@ class PluginRunner:
|
|||||||
return envelope.make_response(payload={"acknowledged": True})
|
return envelope.make_response(payload={"acknowledged": True})
|
||||||
|
|
||||||
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
|
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
|
||||||
"""处理配置更新事件"""
|
"""处理配置更新事件。"""
|
||||||
|
try:
|
||||||
|
payload = ConfigUpdatedPayload.model_validate(envelope.payload)
|
||||||
|
except Exception as exc:
|
||||||
|
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||||
|
|
||||||
plugin_id = envelope.plugin_id
|
plugin_id = envelope.plugin_id
|
||||||
if meta := self._loader.get_plugin(plugin_id):
|
if meta := self._loader.get_plugin(plugin_id):
|
||||||
try:
|
try:
|
||||||
config_data = envelope.payload.get("config_data", {})
|
config_scope = payload.config_scope.value
|
||||||
config_version = envelope.payload.get("config_version", "")
|
if config_scope == "self":
|
||||||
self._apply_plugin_config(meta, config_data=config_data)
|
self._apply_plugin_config(meta, config_data=payload.config_data)
|
||||||
if hasattr(meta.instance, "on_config_update"):
|
if not hasattr(meta.instance, "on_config_update"):
|
||||||
ret = meta.instance.on_config_update(config_data, config_version)
|
raise AttributeError("插件缺少 on_config_update() 实现")
|
||||||
# 兼容同步和异步的 on_config_update 实现
|
|
||||||
|
ret = meta.instance.on_config_update(
|
||||||
|
config_scope,
|
||||||
|
payload.config_data,
|
||||||
|
payload.config_version,
|
||||||
|
)
|
||||||
if asyncio.iscoroutine(ret):
|
if asyncio.iscoroutine(ret):
|
||||||
await ret
|
await ret
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -3,11 +3,11 @@
|
|||||||
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
|
根据聊天上下文的情感,使用 LLM 选择并发送合适的表情包。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import random
|
from maibot_sdk import Action, MaiBotPlugin
|
||||||
|
|
||||||
from maibot_sdk import MaiBotPlugin, Action
|
|
||||||
from maibot_sdk.types import ActivationType
|
from maibot_sdk.types import ActivationType
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class EmojiPlugin(MaiBotPlugin):
|
class EmojiPlugin(MaiBotPlugin):
|
||||||
"""表情包插件"""
|
"""表情包插件"""
|
||||||
@@ -95,10 +95,35 @@ class EmojiPlugin(MaiBotPlugin):
|
|||||||
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
|
return True, f"成功发送表情包:[表情包:{chosen_emotion}]"
|
||||||
return False, "发送表情包失败"
|
return False, "发送表情包失败"
|
||||||
|
|
||||||
async def on_load(self):
|
async def on_load(self) -> None:
|
||||||
|
"""处理插件加载。"""
|
||||||
|
|
||||||
# 从插件配置读取 emoji_chance 来覆盖默认概率
|
# 从插件配置读取 emoji_chance 来覆盖默认概率
|
||||||
await self.ctx.config.get("emoji.emoji_chance")
|
await self.ctx.config.get("emoji.emoji_chance")
|
||||||
|
|
||||||
|
async def on_unload(self) -> None:
|
||||||
|
"""处理插件卸载。"""
|
||||||
|
|
||||||
|
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||||
|
"""处理配置热重载事件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: 配置变更范围。
|
||||||
|
config_data: 最新配置数据。
|
||||||
|
version: 配置版本号。
|
||||||
|
"""
|
||||||
|
|
||||||
|
del config_data
|
||||||
|
del version
|
||||||
|
if scope == "self":
|
||||||
|
await self.ctx.config.get("emoji.emoji_chance")
|
||||||
|
|
||||||
|
|
||||||
|
def create_plugin() -> EmojiPlugin:
|
||||||
|
"""创建 Emoji 插件实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmojiPlugin: 新的 Emoji 插件实例。
|
||||||
|
"""
|
||||||
|
|
||||||
def create_plugin():
|
|
||||||
return EmojiPlugin()
|
return EmojiPlugin()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
通过 /pm 命令管理插件和组件的生命周期。
|
通过 /pm 命令管理插件和组件的生命周期。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from maibot_sdk import MaiBotPlugin, Command
|
from maibot_sdk import Command, MaiBotPlugin
|
||||||
|
|
||||||
|
|
||||||
_VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
|
_VALID_COMPONENT_TYPES = ("action", "command", "event_handler")
|
||||||
@@ -44,6 +44,12 @@ HELP_COMPONENT = (
|
|||||||
class PluginManagementPlugin(MaiBotPlugin):
|
class PluginManagementPlugin(MaiBotPlugin):
|
||||||
"""插件和组件管理插件"""
|
"""插件和组件管理插件"""
|
||||||
|
|
||||||
|
async def on_load(self) -> None:
|
||||||
|
"""处理插件加载。"""
|
||||||
|
|
||||||
|
async def on_unload(self) -> None:
|
||||||
|
"""处理插件卸载。"""
|
||||||
|
|
||||||
@Command(
|
@Command(
|
||||||
"management",
|
"management",
|
||||||
description="管理插件和组件的生命周期",
|
description="管理插件和组件的生命周期",
|
||||||
@@ -268,6 +274,25 @@ class PluginManagementPlugin(MaiBotPlugin):
|
|||||||
return components
|
return components
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def on_config_update(self, scope: str, config_data: dict[str, object], version: str) -> None:
|
||||||
|
"""处理配置热重载事件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: 配置变更范围。
|
||||||
|
config_data: 最新配置数据。
|
||||||
|
version: 配置版本号。
|
||||||
|
"""
|
||||||
|
|
||||||
|
del scope
|
||||||
|
del config_data
|
||||||
|
del version
|
||||||
|
|
||||||
|
|
||||||
|
def create_plugin() -> PluginManagementPlugin:
|
||||||
|
"""创建插件管理插件实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PluginManagementPlugin: 新的插件管理插件实例。
|
||||||
|
"""
|
||||||
|
|
||||||
def create_plugin():
|
|
||||||
return PluginManagementPlugin()
|
return PluginManagementPlugin()
|
||||||
|
|||||||
Reference in New Issue
Block a user