Files
mai-bot/pytests/test_plugin_runtime.py
DrSmoothl 9dea6b0e6f feat: implement dedicated API registry and enhance API handling capabilities
- Added APIEntry and APIRegistry classes for managing plugin APIs.
- Updated PluginRunnerSupervisor to include API registry and methods for invoking APIs.
- Enhanced PluginRuntimeManager to support API registration and invocation.
- Created tests for API registration, invocation, and visibility between plugins.
- Refactored component handling to distinguish between runtime components and APIs.
2026-03-24 12:14:41 +08:00

2449 lines
87 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""插件运行时框架基础测试
验证协议层、传输层、RPC 通信链路的正确性。
"""
from types import SimpleNamespace
import asyncio
import json
import os
import sys
import pytest
# Make the project root and the bundled maibot-plugin-sdk package importable.
_TESTS_DIR = os.path.dirname(__file__)
for _extra_path in (
    os.path.join(_TESTS_DIR, ".."),
    os.path.join(_TESTS_DIR, "..", "packages", "maibot-plugin-sdk"),
):
    # Each entry is pushed to the front, so the SDK path ends up ahead of the project root.
    sys.path.insert(0, _extra_path)
del _TESTS_DIR, _extra_path
# ─── Protocol layer tests ────────────────────────────────
class TestProtocol:
    """Protocol-layer tests: envelopes, the MsgPack codec, request ids and error codes."""

    def test_envelope_create_and_serialize(self):
        """An Envelope keeps its fields and builds a response bound to the same request id."""
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        request = Envelope(
            request_id=1,
            message_type=MessageType.REQUEST,
            method="plugin.invoke_command",
            plugin_id="test_plugin",
            payload={"component_name": "greet", "args": {}},
        )
        assert request.request_id == 1
        assert request.is_request()
        assert request.method == "plugin.invoke_command"

        # make_response must echo the request id and carry the new payload.
        reply = request.make_response(payload={"success": True})
        assert reply.is_response()
        assert reply.request_id == 1
        assert reply.payload["success"] is True

    def test_envelope_make_error_response(self):
        """make_error_response stores the error code and message under ``error``."""
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        request = Envelope(request_id=42, message_type=MessageType.REQUEST, method="cap.request")
        failure = request.make_error_response("E_UNAUTHORIZED", "没有权限")

        error = failure.error
        assert error is not None
        assert error["code"] == "E_UNAUTHORIZED"
        assert error["message"] == "没有权限"

    def test_msgpack_codec(self):
        """An envelope survives a MsgPack encode/decode round trip."""
        from src.plugin_runtime.protocol.codec import MsgPackCodec
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        codec = MsgPackCodec()
        original = Envelope(
            request_id=100,
            message_type=MessageType.REQUEST,
            method="test.method",
            payload={"key": "value", "number": 42},
        )

        wire_bytes = codec.encode_envelope(original)
        assert isinstance(wire_bytes, bytes)

        restored = codec.decode_envelope(wire_bytes)
        assert restored.request_id == 100
        assert restored.method == "test.method"
        assert restored.payload["key"] == "value"
        assert restored.payload["number"] == 42

    def test_json_codec(self):
        """Placeholder: the JSON codec was removed and only MsgPack remains."""

    def test_request_id_generator(self):
        """Request ids start at 1 and increase by one on every call."""
        from src.plugin_runtime.protocol.envelope import RequestIdGenerator

        generator = RequestIdGenerator()
        assert [generator.next() for _ in range(100)] == list(range(1, 101))

    def test_error_codes(self):
        """RPCError exposes its code, shows it in str() and round-trips through a dict."""
        from src.plugin_runtime.protocol.errors import ErrorCode, RPCError

        timeout_error = RPCError(ErrorCode.E_TIMEOUT, "请求超时")
        assert timeout_error.code == ErrorCode.E_TIMEOUT
        assert "E_TIMEOUT" in str(timeout_error)

        rebuilt = RPCError.from_dict(timeout_error.to_dict())
        assert rebuilt.code == ErrorCode.E_TIMEOUT
# ─── Transport layer tests ───────────────────────────────
class TestTransport:
    """Transport-layer tests: framed send/recv over UDS, TCP and Windows named pipes.

    Each framing test closes its client connection and stops its server in ``finally``
    blocks, so a failed assertion does not leak a listening endpoint into later tests.
    """

    @pytest.mark.asyncio
    @pytest.mark.skipif(sys.platform == "win32", reason="Unix domain sockets are not available on Windows")
    async def test_uds_connection_framing(self):
        """A frame sent over UDS reaches the server, and the server's reply comes back intact."""
        from src.plugin_runtime.transport.uds import UDSTransportClient, UDSTransportServer

        server = UDSTransportServer()
        received = asyncio.Event()
        received_data = []

        async def handler(conn):
            data = await conn.recv_frame()
            received_data.append(data)
            await conn.send_frame(b"pong")
            received.set()

        await server.start(handler)
        try:
            client = UDSTransportClient(server.get_address())
            conn = await client.connect()
            try:
                await conn.send_frame(b"ping")
                # Wait for the server-side handler to process the frame.
                await asyncio.wait_for(received.wait(), timeout=5.0)
                assert received_data[0] == b"ping"

                # The reply written by the handler must be readable by the client.
                resp = await conn.recv_frame()
                assert resp == b"pong"
            finally:
                await conn.close()
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_tcp_connection_framing(self):
        """A frame sent over TCP reaches the server, and the server's reply comes back intact."""
        from src.plugin_runtime.transport.tcp import TCPTransportClient, TCPTransportServer

        server = TCPTransportServer()
        received = asyncio.Event()
        received_data = []

        async def handler(conn):
            data = await conn.recv_frame()
            received_data.append(data)
            await conn.send_frame(b"tcp_pong")
            received.set()

        await server.start(handler)
        try:
            # Split at the last ':' so only the trailing port is separated from the host part.
            host, port = server.get_address().rsplit(":", 1)
            client = TCPTransportClient(host, int(port))
            conn = await client.connect()
            try:
                await conn.send_frame(b"tcp_ping")
                await asyncio.wait_for(received.wait(), timeout=5.0)
                assert received_data[0] == b"tcp_ping"

                resp = await conn.recv_frame()
                assert resp == b"tcp_pong"
            finally:
                await conn.close()
        finally:
            await server.stop()

    @pytest.mark.asyncio
    @pytest.mark.skipif(sys.platform != "win32", reason="Windows only")
    async def test_named_pipe_connection_framing(self):
        """A frame sent over a Windows named pipe reaches the server, and the reply comes back."""
        from src.plugin_runtime.transport.named_pipe import NamedPipeTransportClient, NamedPipeTransportServer

        server = NamedPipeTransportServer()
        received = asyncio.Event()
        received_data = []

        async def handler(conn):
            data = await conn.recv_frame()
            received_data.append(data)
            await conn.send_frame(b"pipe_pong")
            received.set()

        await server.start(handler)
        try:
            client = NamedPipeTransportClient(server.get_address())
            conn = await client.connect()
            try:
                await conn.send_frame(b"pipe_ping")
                await asyncio.wait_for(received.wait(), timeout=5.0)
                assert received_data[0] == b"pipe_ping"

                resp = await conn.recv_frame()
                assert resp == b"pipe_pong"
            finally:
                await conn.close()
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_transport_factory(self):
        """The factory builds a server and picks a client for UDS, named-pipe and TCP addresses."""
        from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server

        server = create_transport_server()
        assert server is not None

        # UDS path
        client = create_transport_client("/tmp/test.sock")
        assert client is not None

        # Windows named pipe address
        client = create_transport_client(r"\\.\pipe\maibot-test")
        assert client is not None

        # TCP host:port address
        client = create_transport_client("127.0.0.1:9999")
        assert client is not None
# ─── Host layer tests ────────────────────────────────────
class TestHost:
    """Tests for host-side infrastructure (the capability policy engine)."""

    def test_policy_engine(self):
        """Capabilities are granted per plugin and checked against the registered generation."""
        from src.plugin_runtime.host.policy_engine import PolicyEngine

        engine = PolicyEngine()
        token = engine.register_plugin(
            plugin_id="test_plugin",
            generation=1,
            capabilities=["send.text", "db.query"],
        )
        assert token.plugin_id == "test_plugin"
        assert "send.text" in token.capabilities

        # A declared capability is allowed.
        allowed, _ = engine.check_capability("test_plugin", "send.text")
        assert allowed

        # An undeclared capability is rejected with an explanatory reason.
        allowed, reason = engine.check_capability("test_plugin", "llm.generate")
        assert not allowed
        assert "未获授权" in reason

        # A plugin that was never registered gets nothing.
        allowed, reason = engine.check_capability("unknown", "send.text")
        assert not allowed

        # Asking with a generation that was not registered is rejected.
        allowed, reason = engine.check_capability("test_plugin", "send.text", generation=2)
        assert not allowed
        assert "generation 不匹配" in reason

    def test_policy_engine_allows_parallel_generations(self):
        """During hot reload, active and staged generations each keep their own capability token."""
        from src.plugin_runtime.host.policy_engine import PolicyEngine

        engine = PolicyEngine()
        engine.register_plugin("test_plugin", generation=1, capabilities=["send.text"])
        engine.register_plugin("test_plugin", generation=2, capabilities=["send.text", "llm.generate"])

        granted, _ = engine.check_capability("test_plugin", "send.text", generation=1)
        assert granted is True

        granted, _ = engine.check_capability("test_plugin", "llm.generate", generation=2)
        assert granted is True

        # Generation 1 never declared llm.generate, even though generation 2 did.
        granted, reason = engine.check_capability("test_plugin", "llm.generate", generation=1)
        assert granted is False
        assert "未获授权" in reason

    def test_circuit_breaker_removed(self):
        """Placeholder: the circuit breaker was removed and the supervisor no longer depends on it."""

    def test_circuit_breaker_registry_removed(self):
        """Placeholder: the circuit breaker registry was removed."""
# ─── SDK tests ───────────────────────────────────────────
class TestSDK:
    """Tests for the plugin SDK and for how the runner wires plugin instances up."""

    def test_component_decorators(self):
        """Each component decorator contributes one component with its name and type."""
        from maibot_sdk import Action, Command, EventHandler, MaiBotPlugin, Tool
        from maibot_sdk.types import ActivationType, EventType

        class TestPlugin(MaiBotPlugin):
            @Action("greet", activation_type=ActivationType.KEYWORD, activation_keywords=["hi"])
            async def handle_greet(self, **kwargs):
                return True, "ok"

            @Command("echo", pattern=r"^/echo")
            async def handle_echo(self, **kwargs):
                return True, "echoed", 2

            @Tool("search", parameters={"query": {"type": "string"}})
            async def handle_search(self, **kwargs):
                return {"result": "found"}

            @EventHandler("on_start", event_type=EventType.ON_START)
            async def handle_start(self, **kwargs):
                return True, False, "started"

        plugin = TestPlugin()
        components = plugin.get_components()
        assert len(components) == 4

        names = {c["name"] for c in components}
        assert "greet" in names
        assert "echo" in names
        assert "search" in names
        assert "on_start" in names

        types = {c["type"] for c in components}
        assert "action" in types
        assert "command" in types
        assert "tool" in types
        assert "event_handler" in types

    def test_plugin_context_not_initialized(self):
        """Accessing ``ctx`` before a context has been injected raises RuntimeError."""
        from maibot_sdk import MaiBotPlugin

        plugin = MaiBotPlugin()
        with pytest.raises(RuntimeError, match="尚未初始化"):
            _ = plugin.ctx

    def test_plugin_context_injection(self):
        """An injected PluginContext is exposed through ``ctx`` together with its facades."""
        from maibot_sdk import MaiBotPlugin
        from maibot_sdk.context import PluginContext

        plugin = MaiBotPlugin()
        ctx = PluginContext(plugin_id="test")
        plugin._set_context(ctx)

        assert plugin.ctx.plugin_id == "test"
        assert plugin.ctx.send is not None
        assert plugin.ctx.db is not None
        assert plugin.ctx.llm is not None
        assert plugin.ctx.config is not None

    @pytest.mark.asyncio
    async def test_runner_injected_context_binds_plugin_identity(self):
        """A runner-injected context sends the owning plugin id even if the caller forges ``_plugin_id``."""
        from src.plugin_runtime.runner.runner_main import PluginRunner

        class DummyRPCClient:
            def __init__(self):
                self.calls = []

            async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000):
                self.calls.append(
                    {
                        "method": method,
                        "plugin_id": plugin_id,
                        "payload": payload,
                        "timeout_ms": timeout_ms,
                    }
                )
                return SimpleNamespace(error=None, payload={"result": {"ok": True}})

        class DummyPlugin:
            def _set_context(self, ctx):
                self.ctx = ctx

        runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
        runner._rpc_client = DummyRPCClient()
        plugin = DummyPlugin()
        runner._inject_context("owner_plugin", plugin)

        # Simulate a plugin trying to impersonate another plugin.
        plugin.ctx._plugin_id = "forged_plugin"

        result = await plugin.ctx.call_capability("send.text", text="hello", stream_id="stream-1")
        assert result == {"ok": True}
        assert runner._rpc_client.calls[0]["plugin_id"] == "owner_plugin"
        assert runner._rpc_client.calls[0]["method"] == "cap.request"

    def test_runner_applies_initial_plugin_config(self, tmp_path):
        """The runner passes the plugin's config.toml to ``set_plugin_config`` before on_load.

        ``_apply_plugin_config`` is called without ``await``, so this is a plain (non-async) test.
        """
        from src.plugin_runtime.runner.runner_main import PluginRunner

        class DummyPlugin:
            def __init__(self):
                self.configs = []

            def set_plugin_config(self, config):
                self.configs.append(config)

        plugin_dir = tmp_path / "demo_plugin"
        plugin_dir.mkdir()
        (plugin_dir / "config.toml").write_text("[section]\nvalue = 1\n", encoding="utf-8")

        runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
        plugin = DummyPlugin()
        meta = SimpleNamespace(plugin_id="demo_plugin", plugin_dir=str(plugin_dir), instance=plugin)
        runner._apply_plugin_config(meta)

        assert plugin.configs == [{"section": {"value": 1}}]

    @pytest.mark.asyncio
    async def test_runner_config_update_refreshes_plugin_config_before_callback(self):
        """On a config update the plugin config is refreshed before on_config_update is called."""
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType
        from src.plugin_runtime.runner.runner_main import PluginRunner

        class DummyPlugin:
            def __init__(self):
                self.configs = []
                self.updates = []

            def set_plugin_config(self, config):
                self.configs.append(config)

            async def on_config_update(self, config, version):
                # Snapshot configs so the test can see what was already applied at callback time.
                self.updates.append((config, version, list(self.configs)))

        runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
        plugin = DummyPlugin()
        runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin)

        envelope = Envelope(
            request_id=1,
            message_type=MessageType.REQUEST,
            method="plugin.config_updated",
            plugin_id="demo_plugin",
            payload={"config_data": {"enabled": True}, "config_version": "v2"},
        )
        response = await runner._handle_config_updated(envelope)

        assert response.payload["acknowledged"] is True
        assert plugin.configs == [{"enabled": True}]
        assert plugin.updates == [({"enabled": True}, "v2", [{"enabled": True}])]

    @pytest.mark.asyncio
    async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch, tmp_path):
        """Capability calls made during on_load only happen after plugin.bootstrap was sent."""
        from src.plugin_runtime.runner.runner_main import PluginRunner

        class DummyRPCClient:
            def __init__(self):
                self.calls = []

            async def connect_and_handshake(self):
                return True

            def register_method(self, method, handler):
                return None

            async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000):
                self.calls.append(
                    {
                        "method": method,
                        "plugin_id": plugin_id,
                        "payload": payload,
                        "timeout_ms": timeout_ms,
                    }
                )
                if method == "cap.call":
                    bootstrap_methods = [call["method"] for call in self.calls[:-1]]
                    assert "plugin.bootstrap" in bootstrap_methods
                    return SimpleNamespace(error=None, payload={"success": True})
                return SimpleNamespace(error=None, payload={"accepted": True})

            async def disconnect(self):
                return None

        class DummyPlugin:
            def __init__(self, runner):
                self.runner = runner

            def _set_context(self, ctx):
                self.ctx = ctx

            def get_components(self):
                return [{"name": "handler", "type": "command", "metadata": {}}]

            async def on_load(self):
                result = await self.ctx.call_capability("send.text", text="hello", stream_id="stream-1")
                assert result is True
                # Stop the runner's main loop once on_load has exercised the capability.
                self.runner._shutting_down = True

        runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
        runner._rpc_client = DummyRPCClient()
        plugin = DummyPlugin(runner)
        meta = SimpleNamespace(
            plugin_id="demo_plugin",
            # Per-test directory: the runner may read config.toml from plugin_dir, so a shared
            # /tmp path could pick up a stray file left by something else.
            plugin_dir=str(tmp_path / "demo_plugin"),
            instance=plugin,
            version="1.0.0",
            capabilities_required=["send.text"],
        )
        monkeypatch.setattr(runner, "_install_log_handler", lambda: None)
        monkeypatch.setattr(runner, "_uninstall_log_handler", lambda: asyncio.sleep(0))
        monkeypatch.setattr(runner._loader, "discover_and_load", lambda plugin_dirs: [meta])

        await runner.run()

        methods = [call["method"] for call in runner._rpc_client.calls]
        assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"]
class TestPluginSdkUsage:
    """Checks that in-repo plugins work with the SDK's normalized capability results."""

    def test_runner_skips_signal_handler_registration_on_windows(self, monkeypatch):
        """On Windows the runner must not call ``loop.add_signal_handler``."""
        from src.plugin_runtime.runner import runner_main

        recorded = []

        class DummyLoop:
            def add_signal_handler(self, sig, callback):
                recorded.append((sig, callback))

        monkeypatch.setattr(runner_main.sys, "platform", "win32")
        runner_main._install_shutdown_signal_handlers(lambda: None, DummyLoop())
        assert recorded == []

    @pytest.mark.asyncio
    async def test_builtin_emoji_plugin_handles_normalized_results(self):
        """The built-in emoji plugin runs end to end against canned, normalized capability replies."""
        from maibot_sdk.context import PluginContext
        from src.plugins.built_in.emoji_plugin.plugin import EmojiPlugin

        async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None):
            assert method == "cap.request"
            assert payload is not None
            # A fresh reply table per call, keyed by the requested capability.
            canned_replies = {
                "emoji.get_random": {
                    "success": True,
                    "emojis": [{"base64": "img-1", "emotion": "happy"}],
                },
                "message.get_recent": {"success": True, "messages": [{"id": 1}]},
                "message.build_readable": {"success": True, "text": "最近消息"},
                "llm.generate": {"success": True, "response": "happy", "reasoning": "", "model_name": "m"},
                "send.emoji": {"success": True},
            }
            return canned_replies[payload["capability"]]

        context = PluginContext(plugin_id="emoji", rpc_call=fake_rpc_call)
        emoji_plugin = EmojiPlugin()
        emoji_plugin._set_context(context)

        ok, text = await emoji_plugin.handle_emoji(stream_id="stream-1", reasoning="测试", chat_id="chat-1")
        assert ok is True
        assert "成功发送表情包" in text

    @pytest.mark.asyncio
    async def test_tts_plugin_uses_send_custom_bool_result(self):
        """The TTS action only needs send.custom and reports success when it succeeds."""
        from maibot_sdk.context import PluginContext
        from src.plugins.built_in.tts_plugin.plugin import TTSPlugin

        async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None):
            assert method == "cap.request"
            assert payload is not None
            assert payload["capability"] == "send.custom"
            return {"success": True}

        context = PluginContext(plugin_id="tts", rpc_call=fake_rpc_call)
        tts_plugin = TTSPlugin()
        tts_plugin._set_context(context)

        ok, text = await tts_plugin.handle_tts_action(
            stream_id="stream-1",
            action_data={"voice_text": "你好!!!"},
        )
        assert ok is True
        assert text == "TTS动作执行成功"

    @pytest.mark.asyncio
    async def test_hello_world_plugin_handles_random_emoji_list(self):
        """The hello-world plugin forwards a list of random emojis and keeps processing."""
        from maibot_sdk.context import PluginContext
        from plugins.hello_world_plugin.plugin import HelloWorldPlugin

        async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None):
            assert method == "cap.request"
            assert payload is not None
            canned_replies = {
                "emoji.get_random": {"success": True, "emojis": [{"base64": "img-1"}, {"base64": "img-2"}]},
                "send.forward": {"success": True},
            }
            return canned_replies[payload["capability"]]

        context = PluginContext(plugin_id="hello", rpc_call=fake_rpc_call)
        hello_plugin = HelloWorldPlugin()
        hello_plugin._set_context(context)

        ok, text, keep_going = await hello_plugin.handle_random_emojis(stream_id="stream-1")
        assert ok is True
        assert text == "已发送随机表情包"
        assert keep_going is True
# ─── End-to-end integration tests ────────────────────────
class TestE2E:
    """End-to-end integration tests (Host <-> Runner communication)."""

    @pytest.mark.asyncio
    async def test_handshake(self):
        """A runner.hello carrying the session token is accepted and assigned generation 1.

        The server handler keeps the connection open until the client signals it is done,
        instead of sleeping a fixed second. Network waits are bounded by timeouts, so a
        broken handshake fails quickly instead of hanging the test session.
        """
        import secrets

        from src.plugin_runtime.protocol.codec import MsgPackCodec
        from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
        from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server

        session_token = secrets.token_hex(16)
        codec = MsgPackCodec()
        handshake_done = asyncio.Event()
        client_finished = asyncio.Event()
        server_result = {}

        async def server_handler(conn):
            # Receive and validate the runner's hello.
            data = await conn.recv_frame()
            env = codec.decode_envelope(data)
            assert env.method == "runner.hello"
            hello = HelloPayload.model_validate(env.payload)
            assert hello.session_token == session_token

            # Accept the runner and assign it generation 1.
            resp_payload = HelloResponsePayload(
                accepted=True,
                host_version="1.0",
                assigned_generation=1,
            )
            resp = env.make_response(payload=resp_payload.model_dump())
            await conn.send_frame(codec.encode_envelope(resp))
            server_result["runner_id"] = hello.runner_id
            handshake_done.set()

            # Keep the connection open until the client has finished reading.
            try:
                await asyncio.wait_for(client_finished.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                # Best effort only: the client-side assertions report any real failure.
                pass

        server = create_transport_server()
        await server.start(server_handler)
        try:
            client = create_transport_client(server.get_address())
            conn = await client.connect()
            try:
                hello = HelloPayload(
                    runner_id="test-runner",
                    sdk_version="1.0.0",
                    session_token=session_token,
                )
                env = Envelope(
                    request_id=1,
                    message_type=MessageType.REQUEST,
                    method="runner.hello",
                    payload=hello.model_dump(),
                )
                await conn.send_frame(codec.encode_envelope(env))

                resp_data = await asyncio.wait_for(conn.recv_frame(), timeout=5.0)
                resp = codec.decode_envelope(resp_data)
                resp_payload = HelloResponsePayload.model_validate(resp.payload)
                assert resp_payload.accepted
                assert resp_payload.assigned_generation == 1

                await asyncio.wait_for(handshake_done.wait(), timeout=5.0)
                assert server_result["runner_id"] == "test-runner"
            finally:
                client_finished.set()
                await conn.close()
        finally:
            # Also release the handler if connecting failed before the inner block ran.
            client_finished.set()
            await server.stop()
# ─── Manifest validation tests ───────────────────────────
class TestManifestValidator:
    """Tests for ManifestValidator: required fields, version checks and warnings."""

    def test_valid_manifest(self):
        """A manifest with every required field validates with no errors."""
        from src.plugin_runtime.runner.manifest_validator import ManifestValidator

        checker = ManifestValidator()
        is_valid = checker.validate(
            {
                "manifest_version": 1,
                "name": "test_plugin",
                "version": "1.0.0",
                "description": "测试插件",
                "author": "test",
            }
        )
        assert is_valid is True
        assert len(checker.errors) == 0

    def test_missing_required_fields(self):
        """Each missing required field produces its own error."""
        from src.plugin_runtime.runner.manifest_validator import ManifestValidator

        checker = ManifestValidator()
        assert checker.validate({"manifest_version": 1}) is False
        # name, version, description and author are all missing.
        assert len(checker.errors) >= 4

    def test_unsupported_manifest_version(self):
        """An unknown manifest_version is rejected with an error naming that field."""
        from src.plugin_runtime.runner.manifest_validator import ManifestValidator

        checker = ManifestValidator()
        manifest = {
            "manifest_version": 999,
            "name": "test",
            "version": "1.0",
            "description": "d",
            "author": "a",
        }
        assert checker.validate(manifest) is False
        assert any("manifest_version" in message for message in checker.errors)

    def test_host_version_compatibility(self):
        """A host older than host_application.min_version is reported as incompatible."""
        from src.plugin_runtime.runner.manifest_validator import ManifestValidator

        checker = ManifestValidator(host_version="0.8.5")
        manifest = {
            "name": "test",
            "version": "1.0",
            "description": "d",
            "author": "a",
            "host_application": {"min_version": "0.9.0"},
        }
        assert checker.validate(manifest) is False
        assert any("Host 版本不兼容" in message for message in checker.errors)

    def test_recommended_fields_warning(self):
        """Missing recommended fields produce warnings rather than errors."""
        from src.plugin_runtime.runner.manifest_validator import ManifestValidator

        checker = ManifestValidator()
        checker.validate(
            {
                "name": "test",
                "version": "1.0",
                "description": "d",
                "author": "a",
            }
        )
        # license, keywords and categories are absent.
        assert len(checker.warnings) >= 3
class TestVersionComparator:
    """Tests for VersionComparator normalization, ordering and range checks."""

    def test_normalize(self):
        """Pre-release suffixes are dropped and missing components are zero-filled."""
        from src.plugin_runtime.runner.manifest_validator import VersionComparator

        expectations = {
            "0.8.0-snapshot.1": "0.8.0",
            "1.2": "1.2.0",
            "": "0.0.0",
        }
        for raw, normalized in expectations.items():
            assert VersionComparator.normalize_version(raw) == normalized

    def test_compare(self):
        """compare returns -1, 0 or 1 like a classic three-way comparator."""
        from src.plugin_runtime.runner.manifest_validator import VersionComparator

        cases = (
            ("0.8.0", "0.8.0", 0),
            ("0.8.0", "0.9.0", -1),
            ("1.0.0", "0.9.0", 1),
        )
        for left, right, expected in cases:
            assert VersionComparator.compare(left, right) == expected

    def test_is_in_range(self):
        """Only versions between the lower and upper bound are accepted."""
        from src.plugin_runtime.runner.manifest_validator import VersionComparator

        for version, expected in (("0.8.5", True), ("0.7.0", False), ("1.0.0", False)):
            in_range, _ = VersionComparator.is_in_range(version, "0.8.0", "0.9.0")
            assert bool(in_range) is expected
# ─── Dependency resolution tests ─────────────────────────
class TestDependencyResolution:
    """Tests for plugin dependency ordering and for loading plugins as packages."""

    def test_topological_sort(self):
        """Dependencies are ordered before the plugins that need them."""
        from src.plugin_runtime.runner.plugin_loader import PluginLoader

        loader = PluginLoader()
        candidates = {
            "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
            "auth": (
                "dir_auth",
                {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]},
                "plugin.py",
            ),
            "api": (
                "dir_api",
                {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]},
                "plugin.py",
            ),
        }
        order, failed = loader._resolve_dependencies(candidates)
        assert len(failed) == 0
        assert order.index("core") < order.index("auth")
        assert order.index("auth") < order.index("api")

    def test_missing_dependency(self):
        """A plugin depending on an unknown plugin is reported as failed."""
        from src.plugin_runtime.runner.plugin_loader import PluginLoader

        loader = PluginLoader()
        candidates = {
            "plugin_a": (
                "dir_a",
                {
                    "name": "plugin_a",
                    "version": "1.0",
                    "description": "d",
                    "author": "a",
                    "dependencies": ["nonexistent"],
                },
                "plugin.py",
            ),
        }
        _, failed = loader._resolve_dependencies(candidates)
        assert "plugin_a" in failed
        assert "缺少依赖" in failed["plugin_a"]

    def test_circular_dependency(self):
        """A dependency cycle causes at least one plugin in the cycle to fail."""
        from src.plugin_runtime.runner.plugin_loader import PluginLoader

        loader = PluginLoader()
        candidates = {
            "a": (
                "dir_a",
                {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]},
                "p.py",
            ),
            "b": (
                "dir_b",
                {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]},
                "p.py",
            ),
        }
        _, failed = loader._resolve_dependencies(candidates)
        assert len(failed) >= 1  # at least one plugin in the cycle is rejected

    def test_loader_supports_package_imports_inside_create_plugin(self, tmp_path):
        """``create_plugin()`` may import sibling modules of its own plugin package.

        The generated plugin imports a real package named ``grok_search_plugin``. Afterwards
        the test restores ``sys.path`` and ``sys.meta_path`` and drops that package's
        ``sys.modules`` entries, so it does not leak into later tests.
        """
        from src.plugin_runtime.runner.plugin_loader import PluginLoader

        plugin_root = tmp_path / "plugins"
        plugin_root.mkdir()
        plugin_dir = plugin_root / "grok_search_plugin"
        plugin_dir.mkdir()
        (plugin_dir / "_manifest.json").write_text(
            json.dumps(
                {
                    "name": "grok_search_plugin",
                    "version": "1.0.0",
                    "description": "demo",
                    "author": "tester",
                }
            ),
            encoding="utf-8",
        )
        (plugin_dir / "__init__.py").write_text("VALUE = 1\n", encoding="utf-8")
        (plugin_dir / "services.py").write_text("def answer():\n return 42\n", encoding="utf-8")
        (plugin_dir / "plugin.py").write_text(
            "class DemoPlugin:\n"
            " pass\n\n"
            "def create_plugin():\n"
            " from grok_search_plugin.services import answer\n"
            " plugin = DemoPlugin()\n"
            " plugin.answer = answer\n"
            " return plugin\n",
            encoding="utf-8",
        )

        saved_path = list(sys.path)
        saved_meta_path = list(sys.meta_path)
        try:
            loader = PluginLoader()
            loaded = loader.discover_and_load([str(plugin_root)])
            assert [meta.plugin_id for meta in loaded] == ["grok_search_plugin"]
            assert loader.failed_plugins == {}
            assert loaded[0].instance.answer() == 42
        finally:
            sys.path[:] = saved_path
            sys.meta_path[:] = saved_meta_path
            # Only forget the package created by this test; other modules imported meanwhile stay cached.
            leaked = [name for name in sys.modules if name.partition(".")[0] == "grok_search_plugin"]
            for name in leaked:
                sys.modules.pop(name, None)

    def test_isolate_sys_path_preserves_plugin_dirs(self):
        """``_isolate_sys_path`` keeps the given plugin directories on ``sys.path``."""
        from src.plugin_runtime.runner import runner_main

        plugin_root = os.path.normpath("/tmp/maibot-plugin-root")
        original_path = list(sys.path)
        original_meta_path = list(sys.meta_path)
        try:
            if plugin_root in sys.path:
                sys.path.remove(plugin_root)
            runner_main._isolate_sys_path([plugin_root])
            assert plugin_root in sys.path
        finally:
            sys.path[:] = original_path
            sys.meta_path[:] = original_meta_path
# ─── Host-side ComponentRegistry tests ───────────────────
class TestComponentRegistry:
    """Tests for the host-side component registry."""

    def test_register_and_query(self):
        """Registered components are counted in total and per component type."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        registry = ComponentRegistry()
        registrations = (
            (
                "greet",
                "action",
                "plugin_a",
                {"description": "打招呼", "activation_type": "keyword", "activation_keywords": ["hi"]},
            ),
            ("help", "command", "plugin_a", {"command_pattern": r"^/help"}),
            ("search", "tool", "plugin_b", {"description": "搜索"}),
        )
        for name, component_type, plugin_id, metadata in registrations:
            registry.register_component(name, component_type, plugin_id, metadata)

        stats = registry.get_stats()
        assert stats["total"] == 3
        for component_type in ("action", "command", "tool"):
            assert stats[component_type] == 1

    def test_query_by_type(self):
        """get_components_by_type returns matching components across plugins."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        registry = ComponentRegistry()
        for name, plugin_id in (("a1", "p1"), ("a2", "p2")):
            registry.register_component(name, "action", plugin_id, {})

        assert len(registry.get_components_by_type("action")) == 2

    def test_find_command_by_text(self):
        """Commands are matched against message text by their regex pattern."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        registry = ComponentRegistry()
        registry.register_component("help", "command", "p1", {"command_pattern": r"^/help"})
        registry.register_component("echo", "command", "p1", {"command_pattern": r"^/echo\s"})

        for text, expected_name in (("/help me", "help"), ("/echo hello", "echo")):
            match = registry.find_command_by_text(text)
            assert match is not None
            component, _groups = match
            assert component.name == expected_name

        assert registry.find_command_by_text("no match") is None

    def test_enable_disable(self):
        """A disabled component is hidden only when ``enabled_only=True``."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        registry = ComponentRegistry()
        registry.register_component("a1", "action", "p1", {})
        registry.set_component_enabled("p1.a1", False)

        assert len(registry.get_components_by_type("action", enabled_only=True)) == 0
        assert len(registry.get_components_by_type("action", enabled_only=False)) == 1

    def test_remove_by_plugin(self):
        """Removing a plugin drops only that plugin's components and returns how many were removed."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        registry = ComponentRegistry()
        for name, component_type, plugin_id in (
            ("a1", "action", "p1"),
            ("c1", "command", "p1"),
            ("a2", "action", "p2"),
        ):
            registry.register_component(name, component_type, plugin_id, {})

        assert registry.remove_components_by_plugin("p1") == 2
        assert registry.get_stats()["total"] == 1

    def test_reregister_same_plugin_replaces_component_set(self):
        """Re-registering a plugin after removal keeps only the new component set."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        registry = ComponentRegistry()
        registry.register_plugin_components(
            "p1",
            [
                {"name": "a1", "component_type": "action", "metadata": {}},
                {"name": "a2", "component_type": "action", "metadata": {}},
            ],
        )
        registry.remove_components_by_plugin("p1")
        registry.register_plugin_components(
            "p1",
            [
                {"name": "a1", "component_type": "action", "metadata": {}},
            ],
        )

        assert registry.get_component("p1.a1") is not None
        assert registry.get_component("p1.a2") is None

    def test_event_handlers_sorted_by_weight(self):
        """Event handlers are returned with the highest weight first."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        registry = ComponentRegistry()
        registry.register_component("h_low", "event_handler", "p1", {"event_type": "on_message", "weight": 10})
        registry.register_component("h_high", "event_handler", "p2", {"event_type": "on_message", "weight": 100})

        handlers = registry.get_event_handlers("on_message")
        assert (handlers[0].name, handlers[1].name) == ("h_high", "h_low")

    def test_tools_for_llm(self):
        """Tools are exported for the LLM under ``<plugin>.<name>`` with their raw parameter schema."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        registry = ComponentRegistry()
        registry.register_component(
            "search",
            "tool",
            "p1",
            {"description": "搜索工具", "parameters_raw": {"query": {"type": "string"}}},
        )

        tools = registry.get_tools_for_llm()
        assert len(tools) == 1
        only_tool = tools[0]
        assert only_tool["name"] == "p1.search"
        assert only_tool["parameters"]["query"]["type"] == "string"
# ─── EventDispatcher tests ───────────────────────────────
class TestEventDispatcher:
    """Tests for the host-side event dispatcher."""

    @pytest.mark.asyncio
    async def test_dispatch_non_blocking(self):
        """A non-intercepting handler runs in the background and does not stop processing."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.event_dispatcher import EventDispatcher

        reg = ComponentRegistry()
        reg.register_component(
            "h1",
            "event_handler",
            "p1",
            {
                "event_type": "on_start",
                "weight": 0,
                "intercept_message": False,
            },
        )
        dispatcher = EventDispatcher(reg)
        call_log = []
        invoked = asyncio.Event()

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append((plugin_id, comp_name))
            invoked.set()
            return {"success": True, "continue_processing": True}

        should_continue, _ = await dispatcher.dispatch_event("on_start", mock_invoke)
        assert should_continue

        # Non-blocking dispatch runs the handler asynchronously. Wait for it explicitly
        # rather than sleeping a fixed 100 ms, which is flaky on a loaded machine.
        await asyncio.wait_for(invoked.wait(), timeout=5.0)
        # Yield once more so an unexpected duplicate invocation would be recorded too.
        await asyncio.sleep(0)
        assert call_log == [("p1", "h1")]

    @pytest.mark.asyncio
    async def test_dispatch_intercepting(self):
        """An intercepting handler can stop processing and replace the message."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.event_dispatcher import EventDispatcher

        reg = ComponentRegistry()
        reg.register_component(
            "filter",
            "event_handler",
            "p1",
            {
                "event_type": "on_message_pre_process",
                "weight": 100,
                "intercept_message": True,
            },
        )
        dispatcher = EventDispatcher(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            return {
                "success": True,
                "continue_processing": False,
                "modified_message": {"plain_text": "filtered"},
            }

        should_continue, modified = await dispatcher.dispatch_event(
            "on_message_pre_process", mock_invoke, message={"plain_text": "hello"}
        )
        assert not should_continue
        assert modified is not None
        assert modified["plain_text"] == "filtered"
class TestEventBus:
    """Tests for the core event bus and its IPC bridge to the plugin runtime."""

    @pytest.mark.asyncio
    async def test_bridge_preserves_modified_message(self, monkeypatch):
        """A message rewritten over IPC comes back as a new object; the original stays untouched."""
        import types

        # Replace the message data model module with a stub exposing placeholder names
        # before src.core is imported.
        stub_module = types.ModuleType("src.common.data_models.message_data_model")
        for attribute in ("ReplyContentType", "ReplyContent", "ForwardNode", "ReplySetModel"):
            setattr(stub_module, attribute, object)
        monkeypatch.setitem(sys.modules, "src.common.data_models.message_data_model", stub_module)

        from src.core.event_bus import EventBus
        from src.core.types import EventType, MaiMessages
        from src.plugin_runtime import integration as integration_module

        bus = EventBus()

        async def noop_handler(message):
            return True, message

        bus.subscribe(EventType.ON_MESSAGE, noop_handler, name="noop", intercept=True)

        class FakeManager:
            is_running = True

            async def bridge_event(self, event_type_value, message_dict=None, extra_args=None):
                assert event_type_value == EventType.ON_MESSAGE.value
                return True, {"plain_text": "modified by ipc"}

        monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

        source_message = MaiMessages(plain_text="original")
        keep_processing, rewritten = await bus.emit(EventType.ON_MESSAGE, source_message)

        assert keep_processing is True
        assert rewritten is not None
        assert rewritten.plain_text == "modified by ipc"
        assert source_message.plain_text == "original"
# ─── MaiMessages 测试 ─────────────────────────────────────
class TestMaiMessages:
    """Tests for the unified SDK message model."""

    def test_create_and_serialize(self):
        """A message survives a round trip through its RPC dict form."""
        from maibot_sdk.messages import MaiMessages, MessageSegment

        text_segment = MessageSegment(type="text", data={"text": "hello"})
        message = MaiMessages(
            message_segments=[text_segment],
            plain_text="hello",
            stream_id="stream_1",
        )

        rpc_payload = message.to_rpc_dict()
        assert rpc_payload["plain_text"] == "hello"
        assert len(rpc_payload["message_segments"]) == 1

        restored = MaiMessages.from_rpc_dict(rpc_payload)
        assert restored.plain_text == "hello"

    def test_deepcopy(self):
        """Mutating a deep copy leaves the source message unchanged."""
        from maibot_sdk.messages import MaiMessages

        source = MaiMessages(plain_text="original")
        clone = source.deepcopy()
        clone.plain_text = "modified"
        assert source.plain_text == "original"

    def test_modify_flags(self):
        """Clearing CAN_MODIFY_PROMPT blocks prompt edits but not response edits."""
        from maibot_sdk.messages import MaiMessages
        from maibot_sdk.types import ModifyFlag

        message = MaiMessages(plain_text="hello")
        assert message.can_modify(ModifyFlag.CAN_MODIFY_PROMPT)

        message.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False)
        # The prompt is now locked: the call reports failure and nothing is stored.
        assert not message.modify_prompt("new prompt")
        assert message.llm_prompt is None

        # Response modification is still permitted.
        assert message.modify_response("new response")
        assert message.llm_response_content == "new response"
# ─── WorkflowExecutor 测试 ────────────────────────────────
class TestWorkflowExecutor:
    """Host-side WorkflowExecutor tests (staged pipeline model)."""

    @staticmethod
    def _workflow_step(name, plugin_id, stage, priority, blocking=True, **extra):
        """Return a ``(name, component_type, plugin_id, metadata)`` tuple for a workflow_step.

        ``extra`` keys (e.g. ``filter``, ``error_policy``) are appended to the metadata.
        """
        metadata = {"stage": stage, "priority": priority, "blocking": blocking, **extra}
        return (name, "workflow_step", plugin_id, metadata)

    @staticmethod
    def _make_executor(*components):
        """Register ``components`` in a fresh ComponentRegistry and wrap it in a WorkflowExecutor.

        Each component is a ``(name, component_type, plugin_id, metadata)`` tuple.
        """
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        registry = ComponentRegistry()
        for name, component_type, plugin_id, metadata in components:
            registry.register_component(name, component_type, plugin_id, metadata)
        return WorkflowExecutor(registry)

    @pytest.mark.asyncio
    async def test_empty_pipeline_completes(self):
        """With no workflow_step registered, every stage is skipped and the run completes."""
        executor = self._make_executor()

        async def mock_invoke(plugin_id, comp_name, args):
            return {"hook_result": "continue"}

        result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})

        assert result.status == "completed"
        assert result.return_message == "workflow completed"
        assert len(ctx.timings) == 6  # 6 stages

    @pytest.mark.asyncio
    async def test_blocking_hook_modifies_message(self):
        """A blocking hook may rewrite the message, and the change is logged."""
        executor = self._make_executor(self._workflow_step("upper", "p1", "pre_process", 10))

        async def mock_invoke(plugin_id, comp_name, args):
            msg = args.get("message", {})
            return {
                "hook_result": "continue",
                "modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
            }

        result, final_msg, ctx = await executor.execute(mock_invoke, message={"plain_text": "hello"})

        assert result.status == "completed"
        assert final_msg["plain_text"] == "HELLO"
        assert len(ctx.modification_log) == 1
        assert ctx.modification_log[0].stage == "pre_process"

    @pytest.mark.asyncio
    async def test_abort_stops_pipeline(self):
        """HookResult.ABORT ends the pipeline immediately at the current stage."""
        executor = self._make_executor(self._workflow_step("blocker", "p1", "pre_process", 10))

        async def mock_invoke(plugin_id, comp_name, args):
            return {"hook_result": "abort"}

        result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})

        assert result.status == "aborted"
        assert result.stopped_at == "pre_process"

    @pytest.mark.asyncio
    async def test_skip_stage(self):
        """HookResult.SKIP_STAGE skips the remaining hooks of the current stage."""
        executor = self._make_executor(
            # High-priority hook that answers skip_stage ...
            self._workflow_step("skipper", "p1", "ingress", 100),
            # ... so this low-priority hook in the same stage must not run.
            self._workflow_step("checker", "p2", "ingress", 1),
        )
        call_log = []

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append(comp_name)
            if comp_name == "skipper":
                return {"hook_result": "skip_stage"}
            return {"hook_result": "continue"}

        result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})

        assert result.status == "completed"
        # Only the skipper ran; the checker was skipped.
        assert call_log == ["skipper"]

    @pytest.mark.asyncio
    async def test_pre_filter(self):
        """A hook whose filter does not match the message is not invoked."""
        executor = self._make_executor(
            self._workflow_step("only_dm", "p1", "ingress", 10, filter={"chat_type": "direct"})
        )
        call_log = []

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append(comp_name)
            return {"hook_result": "continue"}

        # Filter does not match: the hook must not be called.
        await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
        assert not call_log

        # Filter matches: the hook must be called.
        await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
        assert call_log == ["only_dm"]

    @pytest.mark.asyncio
    async def test_error_policy_skip(self):
        """With error_policy=skip a failing hook is recorded and the pipeline carries on."""
        executor = self._make_executor(
            self._workflow_step("failer", "p1", "ingress", 100, error_policy="skip"),
            self._workflow_step("ok_step", "p2", "ingress", 1),
        )
        call_log = []

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append(comp_name)
            if comp_name == "failer":
                raise RuntimeError("boom")
            return {"hook_result": "continue"}

        result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})

        assert result.status == "completed"
        assert "failer" in call_log
        assert "ok_step" in call_log
        assert any("boom" in e for e in ctx.errors)

    @pytest.mark.asyncio
    async def test_error_policy_abort(self):
        """With error_policy=abort (the default) a failing hook fails the pipeline."""
        # No error_policy given: relies on the default being "abort".
        executor = self._make_executor(self._workflow_step("failer", "p1", "ingress", 10))

        async def mock_invoke(plugin_id, comp_name, args):
            raise RuntimeError("fatal")

        result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})

        assert result.status == "failed"
        assert result.stopped_at == "ingress"

    @pytest.mark.asyncio
    async def test_nonblocking_hooks_concurrent(self):
        """Non-blocking hooks all run in the background and cannot rewrite the message."""
        executor = self._make_executor(
            *(self._workflow_step(f"nb_{i}", f"p{i}", "post_process", 0, blocking=False) for i in range(3))
        )
        call_log = []

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append(comp_name)
            return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}

        result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})

        assert result.status == "completed"
        # The modified_message of a non-blocking hook is ignored.
        assert final_msg["plain_text"] == "original"

        # Give the background tasks time to finish, then make sure every hook actually ran once.
        await asyncio.sleep(0.1)
        assert sorted(call_log) == ["nb_0", "nb_1", "nb_2"]

    @pytest.mark.asyncio
    async def test_nonblocking_tasks_are_retained_until_completion(self):
        """After execute returns, non-blocking tasks stay strongly referenced until they finish."""
        executor = self._make_executor(
            self._workflow_step("observer", "p1", "post_process", 0, blocking=False)
        )
        started = asyncio.Event()
        release = asyncio.Event()

        async def mock_invoke(plugin_id, comp_name, args):
            started.set()
            await release.wait()
            return {"hook_result": "continue"}

        result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
        await asyncio.sleep(0)

        try:
            assert result.status == "completed"
            assert final_msg["plain_text"] == "original"
            assert started.is_set()
            assert len(executor._background_tasks) == 1
        finally:
            # Always unblock the hook so a failed assertion does not leave the task waiting forever.
            release.set()

        await asyncio.sleep(0)
        await asyncio.sleep(0)
        assert not executor._background_tasks

    @pytest.mark.asyncio
    async def test_command_routing(self):
        """The PLAN stage routes a matching message to the registered command."""
        executor = self._make_executor(("help", "command", "p1", {"command_pattern": r"^/help"}))

        async def mock_invoke(plugin_id, comp_name, args):
            if comp_name == "help":
                return {"output": "帮助信息"}
            return {"hook_result": "continue"}

        result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"})

        assert result.status == "completed"
        assert ctx.matched_command == "p1.help"
        cmd_result = ctx.get_stage_output("plan", "command_result")
        assert cmd_result is not None
        assert cmd_result["output"] == "帮助信息"

    @pytest.mark.asyncio
    async def test_stage_outputs(self):
        """stage_output written in one stage is visible to hooks of later stages."""
        executor = self._make_executor(
            # The ingress step publishes data ...
            self._workflow_step("writer", "p1", "ingress", 10),
            # ... which the pre_process step reads back.
            self._workflow_step("reader", "p2", "pre_process", 10),
        )
        seen_by_reader = []

        async def mock_invoke(plugin_id, comp_name, args):
            if comp_name == "writer":
                return {
                    "hook_result": "continue",
                    "stage_output": {"parsed_intent": "greeting"},
                }
            if comp_name == "reader":
                # Record instead of asserting here: an exception raised inside a hook is handled
                # by the executor's error policy rather than failing the test directly. Recording
                # also proves the reader actually ran.
                ingress_data = args.get("stage_outputs", {}).get("ingress", {})
                seen_by_reader.append(ingress_data.get("parsed_intent"))
            return {"hook_result": "continue"}

        result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"})

        assert result.status == "completed"
        assert seen_by_reader == ["greeting"]
        assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
class TestRPCServer:
    """RPCServer tests: runner-generation guarding and send-queue backpressure."""

    class _DummyTransport:
        """No-op transport; the tests attach a fake connection to the server directly."""

        async def start(self, handler):
            return None

        async def stop(self):
            return None

        def get_address(self):
            return "dummy"

    def test_ignore_stale_generation_response(self):
        """A response from an older runner generation must not resolve a newer pending request."""
        from src.plugin_runtime.host.rpc_server import RPCServer
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        server = RPCServer(transport=self._DummyTransport())
        server._runner_generation = 2

        # _handle_response is called synchronously, so a private loop is enough to create the future.
        loop = asyncio.new_event_loop()
        try:
            future = loop.create_future()
            # Request 1 is pending and belongs to generation 2.
            server._pending_requests[1] = (future, 2)
            stale_response = Envelope(
                request_id=1,
                message_type=MessageType.RESPONSE,
                method="plugin.health",
                generation=1,
                payload={"healthy": True},
            )

            server._handle_response(stale_response)

            assert not future.done()
            assert 1 in server._pending_requests
        finally:
            loop.close()

    @pytest.mark.asyncio
    async def test_send_queue_backpressure_is_enforced(self):
        """With send_queue_size=1 and a blocked connection, a third send fails with E_BACKPRESSURE."""
        from src.plugin_runtime.host.rpc_server import RPCServer
        from src.plugin_runtime.protocol.errors import ErrorCode, RPCError

        class BlockingConnection:
            """Connection whose send_frame blocks until ``release`` is set."""

            def __init__(self):
                self.is_closed = False
                self.release = asyncio.Event()

            async def send_frame(self, data):
                await self.release.wait()

            async def close(self):
                self.is_closed = True

        server = RPCServer(transport=self._DummyTransport(), send_queue_size=1)
        await server.start()
        conn = BlockingConnection()
        server._connection = conn
        server._runner_generation = 1

        pending_sends = []
        try:
            pending_sends.append(asyncio.create_task(server.send_event("runner.log_batch")))
            await asyncio.sleep(0)
            pending_sends.append(asyncio.create_task(server.send_event("runner.log_batch")))
            await asyncio.sleep(0)

            with pytest.raises(RPCError) as exc_info:
                await server.send_event("runner.log_batch")
            assert exc_info.value.code == ErrorCode.E_BACKPRESSURE
        finally:
            # Unblock the connection and drain the earlier sends even when an assertion failed,
            # so no send stays blocked and the server is always stopped.
            conn.release.set()
            try:
                await asyncio.gather(*pending_sends)
            finally:
                await server.stop()
class TestRPCClient:
    """Lifecycle of background tasks owned by the runner-side RPCClient."""

    @pytest.mark.asyncio
    async def test_background_tasks_retained_and_cancelled_on_disconnect(self):
        """Tracked tasks stay referenced while running and are cancelled by disconnect()."""
        from src.plugin_runtime.runner.rpc_client import RPCClient

        client = RPCClient(host_address="dummy", session_token="token")
        never_set = asyncio.Event()

        async def wait_forever():
            await never_set.wait()

        background = asyncio.create_task(wait_forever())
        client._track_background_task(background)
        assert background in client._background_tasks

        # Let the task start running; it must remain tracked afterwards.
        await asyncio.sleep(0)
        assert background in client._background_tasks

        await client.disconnect()
        assert background.cancelled() is True
        assert not client._background_tasks
class TestSupervisor:
    """Supervisor lifecycle edge cases: staged reloads, rollback and stderr draining."""

    class _FakeRPCServer:
        """Minimal stand-in for the RPCServer staged-takeover API used by reload_plugins().

        Starts on runner generation 1 and stages generation 2. By default the health check
        (``send_request``) reports healthy; tests subclass it to change that or the commit step.
        """

        def __init__(self):
            self.runner_generation = 1
            self.staged_generation = 0
            self.is_connected = True
            self.session_token = "fake-token"
            # Flags inspected by the tests after reload_plugins() returns.
            self.staging_started = False
            self.committed = False
            self.rolled_back = False

        def reset_session_token(self):
            self.session_token = "new-fake-token"
            return self.session_token

        def restore_session_token(self, token):
            self.session_token = token

        def begin_staged_takeover(self):
            self.staging_started = True
            self.staged_generation = 2

        async def commit_staged_takeover(self):
            self.runner_generation = self.staged_generation
            self.staged_generation = 0
            self.committed = True

        async def rollback_staged_takeover(self):
            self.rolled_back = True
            self.staged_generation = 0

        def has_generation(self, generation):
            return generation in {self.runner_generation, self.staged_generation}

        async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
            from src.plugin_runtime.protocol.envelope import HealthPayload

            return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())

    @staticmethod
    def _build_register_payload(plugin_id: str = "plugin_a", component_names=None):
        """Build a RegisterComponentsPayload with one event_handler per name.

        ``component_names`` defaults to ``["handler"]`` when omitted; an explicit empty list
        produces a plugin with no components.
        """
        from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload

        if component_names is None:
            component_names = ["handler"]
        declarations = [
            ComponentDeclaration(
                name=name,
                component_type="event_handler",
                plugin_id=plugin_id,
                metadata={"event_type": "on_message"},
            )
            for name in component_names
        ]
        return RegisterComponentsPayload(
            plugin_id=plugin_id,
            plugin_version="1.0.0",
            components=declarations,
            capabilities_required=["send.text"],
        )

    @staticmethod
    def _make_process(pid: int):
        """Return a fake subprocess handle that records terminate()/kill() calls."""

        class FakeProcess:
            def __init__(self):
                self.pid = pid
                self.returncode = None
                self.stdout = None
                self.stderr = None
                self.terminated = False
                self.killed = False

            def terminate(self):
                self.terminated = True
                self.returncode = 0

            def kill(self):
                self.killed = True
                self.returncode = -9

            async def wait(self):
                return self.returncode

        return FakeProcess()

    @staticmethod
    def _install_fake_spawn(monkeypatch, supervisor, new_process, staged_reg, *, signal_ready=True):
        """Replace ``supervisor._spawn_runner`` with a fake that stages ``staged_reg`` as generation 2.

        When ``signal_ready`` is False the fake never publishes a runner.ready payload or event,
        simulating a runner that never reports readiness.
        """

        async def fake_spawn_runner():
            supervisor._runner_process = new_process
            supervisor._staged_registered_plugins[staged_reg.plugin_id] = staged_reg
            if not signal_ready:
                return
            supervisor._runner_ready_payloads[2] = SimpleNamespace(
                loaded_plugins=[staged_reg.plugin_id], failed_plugins=[]
            )
            ready_event = asyncio.Event()
            ready_event.set()
            supervisor._runner_ready_events[2] = ready_event

        monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)

    @pytest.mark.asyncio
    async def test_reload_waits_for_target_generation(self, monkeypatch):
        """The reload health check targets the staged generation before the takeover is committed."""
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[])
        old_process = self._make_process(1)
        new_process = self._make_process(2)

        class TargetCheckingRPCServer(self._FakeRPCServer):
            async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
                assert target_generation == 2
                return await super().send_request(
                    method, timeout_ms=timeout_ms, target_generation=target_generation, **kwargs
                )

        supervisor._rpc_server = TargetCheckingRPCServer()
        supervisor._runner_process = old_process
        self._install_fake_spawn(monkeypatch, supervisor, new_process, self._build_register_payload("plugin_a"))

        # NOTE(review): "test" is passed positionally; if the first parameter of reload_plugins is
        # plugin_ids (as the integration-layer fakes suggest), reason="test" may be intended — confirm.
        reloaded = await supervisor.reload_plugins("test")

        assert reloaded is True
        assert supervisor._runner_process is new_process
        assert supervisor._rpc_server.staging_started is True
        assert supervisor._rpc_server.committed is True
        assert old_process.terminated is True

    @pytest.mark.asyncio
    async def test_reload_restores_runtime_state_on_failure(self, monkeypatch):
        """A failing health check rolls back to the old runner and keeps its components registered."""
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[])
        old_process = self._make_process(1)
        new_process = self._make_process(2)
        old_reg = self._build_register_payload()
        supervisor._runner_process = old_process
        supervisor._registered_plugins[old_reg.plugin_id] = old_reg
        supervisor._rebuild_runtime_state()

        class UnhealthyRPCServer(self._FakeRPCServer):
            async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
                raise RuntimeError("new runner unhealthy")

        supervisor._rpc_server = UnhealthyRPCServer()
        self._install_fake_spawn(monkeypatch, supervisor, new_process, self._build_register_payload("plugin_a"))

        reloaded = await supervisor.reload_plugins("test")

        assert reloaded is False
        assert supervisor._runner_process is old_process
        assert supervisor._rpc_server.rolled_back is True
        assert old_reg.plugin_id in supervisor._registered_plugins
        assert supervisor.component_registry.get_component("plugin_a.handler") is not None

    @pytest.mark.asyncio
    async def test_reload_rebuilds_exact_component_set(self, monkeypatch):
        """After a reload the registry holds exactly the new component set; dropped ones disappear."""
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[])
        old_process = self._make_process(1)
        new_process = self._make_process(2)
        old_reg = self._build_register_payload("plugin_a", component_names=["handler", "obsolete"])
        new_reg = self._build_register_payload("plugin_a", component_names=["handler"])
        supervisor._runner_process = old_process
        supervisor._registered_plugins[old_reg.plugin_id] = old_reg
        supervisor._rebuild_runtime_state()

        supervisor._rpc_server = self._FakeRPCServer()
        self._install_fake_spawn(monkeypatch, supervisor, new_process, new_reg)

        reloaded = await supervisor.reload_plugins("test")

        assert reloaded is True
        assert supervisor.component_registry.get_component("plugin_a.handler") is not None
        assert supervisor.component_registry.get_component("plugin_a.obsolete") is None

    @pytest.mark.asyncio
    async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch):
        """Without runner.ready the reload rolls back without health-checking or committing."""
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[], runner_spawn_timeout_sec=0.01)
        old_process = self._make_process(1)
        new_process = self._make_process(2)
        old_reg = self._build_register_payload()
        supervisor._runner_process = old_process
        supervisor._registered_plugins[old_reg.plugin_id] = old_reg
        supervisor._rebuild_runtime_state()

        class NeverReadyRPCServer(self._FakeRPCServer):
            async def commit_staged_takeover(self):
                raise AssertionError("runner.ready 未到达前不应提交 staged takeover")

            async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
                raise AssertionError("runner.ready 未到达前不应执行健康检查")

        supervisor._rpc_server = NeverReadyRPCServer()
        self._install_fake_spawn(
            monkeypatch,
            supervisor,
            new_process,
            self._build_register_payload("plugin_a"),
            signal_ready=False,
        )

        reloaded = await supervisor.reload_plugins("test")

        assert reloaded is False
        assert supervisor._runner_process is old_process
        assert supervisor._rpc_server.rolled_back is True

    @pytest.mark.asyncio
    async def test_attach_stderr_drain_drains_stream(self):
        """_attach_stderr_drain consumes stderr; the drain task finishes once EOF is reached."""
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[])
        stderr = asyncio.StreamReader()
        stderr.feed_data(b"fatal startup error\n")
        stderr.feed_eof()
        # stdout=None mirrors the current architecture, which no longer captures stdout.
        process = SimpleNamespace(pid=99, stdout=None, stderr=stderr)

        supervisor._attach_stderr_drain(process)
        # Give the drain task enough time to consume the buffered data.
        await asyncio.sleep(0.05)

        drain_task = supervisor._stderr_drain_task
        assert drain_task is None or drain_task.done()
class TestIntegration:
"""运行时集成层启动/清理测试"""
@pytest.mark.asyncio
async def test_cap_database_get_with_filters_does_not_reference_unbound_key_value(self, monkeypatch):
    """database.get with a filters dict forwards the filters and returns the rows as the result."""
    from src.plugin_runtime import integration as integration_module
    import src.common.database.database_model as real_db_models
    from src.services import database_service as real_database_service

    recorded: dict[str, object] = {}

    class DummyModel:
        pass

    async def fake_db_get(model_class, filters=None, limit=None, order_by=None, single_result=False):
        # Keep every argument so the test can inspect what the capability forwarded.
        recorded.update(
            model_class=model_class,
            filters=filters,
            limit=limit,
            order_by=order_by,
            single_result=single_result,
        )
        return [{"id": 1}]

    monkeypatch.setattr(real_database_service, "db_get", fake_db_get)
    monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False)

    request_args = {
        "table": "DemoTable",
        "filters": {"status": "active"},
        "limit": 5,
    }
    result = await integration_module.PluginRuntimeManager._cap_database_get(
        "plugin_a", "database.get", request_args
    )

    assert result == {"success": True, "result": [{"id": 1}]}
    assert recorded["model_class"] is DummyModel
    assert recorded["filters"] == {"status": "active"}
    assert recorded["limit"] == 5
    assert recorded["single_result"] is False
@pytest.mark.asyncio
async def test_component_enable_rejects_ambiguous_short_name(self, monkeypatch):
    """Enabling by a short name that two plugins share is rejected as ambiguous."""
    from src.plugin_runtime import integration as integration_module
    from src.plugin_runtime.host.component_registry import ComponentRegistry

    class FakeSupervisor:
        def __init__(self, plugin_id: str):
            registry = ComponentRegistry()
            registry.register_component(
                name="shared",
                component_type="tool",
                plugin_id=plugin_id,
                metadata={},
            )
            self.component_registry = registry

    class FakeManager:
        def __init__(self):
            self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")]

    monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

    manager = integration_module.PluginRuntimeManager()
    manager._builtin_supervisor = FakeSupervisor("plugin_a")
    manager._third_party_supervisor = FakeSupervisor("plugin_b")

    enable_request = {"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""}
    outcome = await manager._cap_component_enable("plugin_a", "component.enable", enable_request)

    assert outcome["success"] is False
    assert "组件名不唯一" in outcome["error"]
@pytest.mark.asyncio
async def test_component_disable_rejects_non_global_scope(self, monkeypatch):
    """component.disable supports only the global scope; a stream-scoped request is refused."""
    from src.plugin_runtime import integration as integration_module
    from src.plugin_runtime.host.component_registry import ComponentRegistry

    class FakeSupervisor:
        def __init__(self):
            registry = ComponentRegistry()
            registry.register_component(
                name="handler",
                component_type="tool",
                plugin_id="plugin_a",
                metadata={},
            )
            self.component_registry = registry

    class FakeManager:
        def __init__(self):
            self.supervisors = [FakeSupervisor()]

    monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

    manager = integration_module.PluginRuntimeManager()
    manager._builtin_supervisor = FakeSupervisor()

    disable_request = {
        "name": "plugin_a.handler",
        "component_type": "tool",
        "scope": "stream",
        "stream_id": "s1",
    }
    outcome = await manager._cap_component_disable("plugin_a", "component.disable", disable_request)

    assert outcome["success"] is False
    assert "仅支持全局组件禁用" in outcome["error"]
@pytest.mark.asyncio
async def test_start_cleans_up_started_supervisors_on_failure(self, monkeypatch):
    """If the second supervisor fails to start, the already-started first one is stopped."""
    from src.plugin_runtime import integration as integration_module

    created = []

    class FakeCapabilityService:
        def register_capability(self, name, impl):
            return None

    class FakeSupervisor:
        def __init__(self, plugin_dirs=None, socket_path=None):
            self.plugin_dirs = plugin_dirs or []
            self.capability_service = FakeCapabilityService()
            self.stopped = False
            created.append(self)

        async def start(self):
            # Only the second supervisor to be constructed fails to start.
            if len(created) == 2 and created[1] is self:
                raise RuntimeError("boom")

        async def stop(self):
            self.stopped = True

    monkeypatch.setattr(
        integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])
    )
    monkeypatch.setattr(
        integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])
    )

    import src.plugin_runtime.host.supervisor as supervisor_module

    monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor)

    manager = integration_module.PluginRuntimeManager()
    await manager.start()

    assert manager.is_running is False
    assert len(created) == 2
    assert created[0].stopped is True
@pytest.mark.asyncio
async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path):
    """A source change under the third-party root reloads only that supervisor's plugin."""
    from src.config.file_watcher import FileChange
    from src.plugin_runtime import integration as integration_module

    # json comes from the module-level import; no local re-import is needed.
    builtin_root = tmp_path / "src" / "plugins" / "built_in"
    thirdparty_root = tmp_path / "plugins"
    alpha_dir = builtin_root / "alpha"
    beta_dir = thirdparty_root / "beta"
    for plugin_dir, config_text in ((alpha_dir, "enabled = true\n"), (beta_dir, "enabled = false\n")):
        plugin_dir.mkdir(parents=True)
        (plugin_dir / "config.toml").write_text(config_text, encoding="utf-8")
        (plugin_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
        (plugin_dir / "_manifest.json").write_text(json.dumps({"name": plugin_dir.name}), encoding="utf-8")
    monkeypatch.chdir(tmp_path)

    class FakeSupervisor:
        def __init__(self, plugin_dirs, registered_plugins):
            self._plugin_dirs = plugin_dirs
            self._registered_plugins = registered_plugins
            self.reload_reasons = []
            self.config_updates = []

        async def reload_plugins(self, plugin_ids=None, reason="manual"):
            self.reload_reasons.append((plugin_ids, reason))

        async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
            self.config_updates.append((plugin_id, config_data, config_version))
            return True

    manager = integration_module.PluginRuntimeManager()
    manager._started = True
    manager._builtin_supervisor = FakeSupervisor([builtin_root], {"alpha": object()})
    manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"beta": object()})

    refresh_calls = []

    def fake_refresh() -> None:
        refresh_calls.append(True)

    manager._refresh_plugin_config_watch_subscriptions = fake_refresh

    changes = [FileChange(change_type=1, path=beta_dir / "plugin.py")]
    await manager._handle_plugin_source_changes(changes)

    assert manager._builtin_supervisor.reload_reasons == []
    assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")]
    assert manager._builtin_supervisor.config_updates == []
    assert manager._third_party_supervisor.config_updates == []
    assert refresh_calls == [True]
@pytest.mark.asyncio
async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path):
    """A config.toml change reloads just the owning plugin, on its own supervisor only."""
    from src.plugin_runtime import integration as integration_module
    from src.config.file_watcher import FileChange

    builtin_root = tmp_path / "src" / "plugins" / "built_in"
    thirdparty_root = tmp_path / "plugins"
    alpha_dir = builtin_root / "alpha"
    beta_dir = thirdparty_root / "beta"
    for plugin_dir, config_text in ((alpha_dir, "enabled = true\n"), (beta_dir, "enabled = false\n")):
        plugin_dir.mkdir(parents=True)
        (plugin_dir / "config.toml").write_text(config_text, encoding="utf-8")
    monkeypatch.chdir(tmp_path)

    class FakeSupervisor:
        def __init__(self, plugin_dirs, plugins):
            self._plugin_dirs = plugin_dirs
            self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
            self.reload_calls = []

        async def reload_plugin(self, plugin_id, reason="manual"):
            self.reload_calls.append((plugin_id, reason))
            return True

    manager = integration_module.PluginRuntimeManager()
    manager._started = True
    manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
    manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])

    refresh_calls = []
    manager._refresh_plugin_config_watch_subscriptions = lambda: refresh_calls.append(True)

    alpha_config_change = FileChange(change_type=1, path=alpha_dir / "config.toml")
    await manager._handle_plugin_config_changes("alpha", [alpha_config_change])

    assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")]
    assert manager._third_party_supervisor.reload_calls == []
    assert refresh_calls == [True]
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
    """One config.toml watch subscription is registered for each plugin of each supervisor."""
    from src.plugin_runtime import integration as integration_module

    # json comes from the module-level import; no local re-import is needed.
    builtin_root = tmp_path / "src" / "plugins" / "built_in"
    thirdparty_root = tmp_path / "plugins"
    alpha_dir = builtin_root / "alpha"
    beta_dir = thirdparty_root / "beta"
    for plugin_dir in (alpha_dir, beta_dir):
        plugin_dir.mkdir(parents=True)
        (plugin_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
        (plugin_dir / "_manifest.json").write_text(json.dumps({"name": plugin_dir.name}), encoding="utf-8")

    class FakeWatcher:
        """Records subscribe/unsubscribe calls and hands out sequential subscription ids."""

        def __init__(self):
            self.subscriptions = []
            self.unsubscribed = []

        def subscribe(self, callback, *, paths=None, change_types=None):
            subscription_id = f"sub-{len(self.subscriptions) + 1}"
            self.subscriptions.append({"id": subscription_id, "callback": callback, "paths": tuple(paths or ())})
            return subscription_id

        def unsubscribe(self, subscription_id):
            self.unsubscribed.append(subscription_id)
            return True

    class FakeSupervisor:
        def __init__(self, plugin_dirs, plugins):
            self._plugin_dirs = plugin_dirs
            self._registered_plugins = {plugin_id: object() for plugin_id in plugins}

    manager = integration_module.PluginRuntimeManager()
    manager._plugin_file_watcher = FakeWatcher()
    manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
    manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])

    manager._refresh_plugin_config_watch_subscriptions()

    assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"alpha", "beta"}
    watched_paths = {subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions}
    assert watched_paths == {alpha_dir / "config.toml", beta_dir / "config.toml"}
@pytest.mark.asyncio
async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
    """``component.reload_plugin`` reports failure when the supervisor reload is rolled back.

    The only supervisor has ``alpha`` registered, and its ``reload_plugins``
    returns ``False``. The capability result must have ``success`` set to
    ``False`` and an error message mentioning the rollback ("已回滚").
    """
    from src.plugin_runtime import integration as integration_module

    class FakeSupervisor:
        def __init__(self):
            self._registered_plugins = {"alpha": object()}

        async def reload_plugins(self, reason="manual"):
            # Simulate a reload that failed and was rolled back.
            return False

    class FakeManager:
        def __init__(self):
            self.supervisors = [FakeSupervisor()]

    # Calling the class with no arguments yields a fresh FakeManager per lookup.
    monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", FakeManager)

    request_args = {"plugin_name": "alpha"}
    outcome = await integration_module.PluginRuntimeManager._cap_component_reload_plugin(
        "plugin_a", "component.reload_plugin", request_args
    )

    assert outcome["success"] is False
    assert "已回滚" in outcome["error"]
@pytest.mark.asyncio
async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path):
    """``component.load_plugin`` reports failure when the supervisor reload is rolled back.

    An ``alpha`` directory exists under the supervisor's plugin root but is
    not yet registered. The supervisor's ``reload_plugins`` returns
    ``False``. The capability result must have ``success`` set to ``False``
    and an error message mentioning the rollback ("已回滚").
    """
    from src.plugin_runtime import integration as integration_module

    plugin_root = tmp_path / "plugins"
    # Creates both ``plugins`` and ``plugins/alpha`` in one call.
    (plugin_root / "alpha").mkdir(parents=True)

    class FakeSupervisor:
        def __init__(self):
            self._registered_plugins = {}
            self._plugin_dirs = [str(plugin_root)]

        async def reload_plugins(self, reason="manual"):
            # Simulate a reload that failed and was rolled back.
            return False

    class FakeManager:
        def __init__(self):
            self.supervisors = [FakeSupervisor()]

    # Calling the class with no arguments yields a fresh FakeManager per lookup.
    monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", FakeManager)

    request_args = {"plugin_name": "alpha"}
    result = await integration_module.PluginRuntimeManager._cap_component_load_plugin(
        "plugin_a", "component.load_plugin", request_args
    )

    assert result["success"] is False
    assert "已回滚" in result["error"]