1337 lines
46 KiB
Python
1337 lines
46 KiB
Python
"""插件运行时框架基础测试
|
||
|
||
验证协议层、传输层、RPC 通信链路的正确性。
|
||
"""
|
||
|
||
from types import SimpleNamespace
|
||
|
||
import asyncio
|
||
import os
|
||
import sys
|
||
|
||
import pytest
|
||
|
||
# Make the project root (the parent of this test directory) importable.
# The path is normalised to an absolute one so the entry does not depend on
# the current working directory at import time.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, _PROJECT_ROOT)

# Make the bundled plugin SDK package importable as well. It is inserted
# after the project root, so it ends up first on sys.path (same order as before).
sys.path.insert(0, os.path.join(_PROJECT_ROOT, "packages", "maibot-plugin-sdk"))
|
||
|
||
|
||
# ─── 协议层测试 ───────────────────────────────────────────
|
||
|
||
class TestProtocol:
|
||
"""协议层测试"""
|
||
|
||
def test_envelope_create_and_serialize(self):
|
||
"""Envelope 创建与序列化"""
|
||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||
|
||
env = Envelope(
|
||
request_id=1,
|
||
message_type=MessageType.REQUEST,
|
||
method="plugin.invoke_command",
|
||
plugin_id="test_plugin",
|
||
payload={"component_name": "greet", "args": {}},
|
||
)
|
||
|
||
assert env.request_id == 1
|
||
assert env.is_request()
|
||
assert env.method == "plugin.invoke_command"
|
||
|
||
# 测试 make_response
|
||
resp = env.make_response(payload={"success": True})
|
||
assert resp.is_response()
|
||
assert resp.request_id == 1
|
||
assert resp.payload["success"] is True
|
||
|
||
def test_envelope_make_error_response(self):
|
||
"""错误响应生成"""
|
||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||
|
||
env = Envelope(
|
||
request_id=42,
|
||
message_type=MessageType.REQUEST,
|
||
method="cap.request",
|
||
)
|
||
|
||
err_resp = env.make_error_response("E_UNAUTHORIZED", "没有权限")
|
||
assert err_resp.error is not None
|
||
assert err_resp.error["code"] == "E_UNAUTHORIZED"
|
||
assert err_resp.error["message"] == "没有权限"
|
||
|
||
def test_msgpack_codec(self):
|
||
"""MsgPack 编解码"""
|
||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||
|
||
codec = MsgPackCodec()
|
||
env = Envelope(
|
||
request_id=100,
|
||
message_type=MessageType.REQUEST,
|
||
method="test.method",
|
||
payload={"key": "value", "number": 42},
|
||
)
|
||
|
||
# 编码
|
||
data = codec.encode_envelope(env)
|
||
assert isinstance(data, bytes)
|
||
|
||
# 解码
|
||
decoded = codec.decode_envelope(data)
|
||
assert decoded.request_id == 100
|
||
assert decoded.method == "test.method"
|
||
assert decoded.payload["key"] == "value"
|
||
assert decoded.payload["number"] == 42
|
||
|
||
def test_json_codec(self):
|
||
"""JSON 编解码已移除,仅保留 MsgPack"""
|
||
pass
|
||
|
||
def test_request_id_generator(self):
|
||
"""请求 ID 生成器单调递增"""
|
||
from src.plugin_runtime.protocol.envelope import RequestIdGenerator
|
||
|
||
gen = RequestIdGenerator()
|
||
ids = [gen.next() for _ in range(100)]
|
||
assert ids == list(range(1, 101))
|
||
|
||
def test_error_codes(self):
|
||
"""错误码枚举"""
|
||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||
|
||
err = RPCError(ErrorCode.E_TIMEOUT, "请求超时")
|
||
assert err.code == ErrorCode.E_TIMEOUT
|
||
assert "E_TIMEOUT" in str(err)
|
||
|
||
# 序列化/反序列化
|
||
d = err.to_dict()
|
||
err2 = RPCError.from_dict(d)
|
||
assert err2.code == ErrorCode.E_TIMEOUT
|
||
|
||
|
||
# ─── 传输层测试 ───────────────────────────────────────────
|
||
|
||
class TestTransport:
    """Transport-layer tests for UDS/TCP framing and the transport factory."""

    @pytest.mark.asyncio
    async def test_uds_connection_framing(self):
        """A frame sent over UDS reaches the server and the reply is readable."""
        from src.plugin_runtime.transport.uds import UDSTransportClient, UDSTransportServer

        server = UDSTransportServer()
        received = asyncio.Event()
        received_data = []

        async def handler(conn):
            data = await conn.recv_frame()
            received_data.append(data)
            await conn.send_frame(b"pong")
            received.set()

        await server.start(handler)
        try:
            client = UDSTransportClient(server.get_address())
            conn = await client.connect()
            try:
                await conn.send_frame(b"ping")

                # Wait until the server side has processed the frame.
                await asyncio.wait_for(received.wait(), timeout=5.0)
                assert received_data[0] == b"ping"

                # Read the server's reply; bounded so a missing reply fails
                # the test instead of hanging it.
                resp = await asyncio.wait_for(conn.recv_frame(), timeout=5.0)
                assert resp == b"pong"
            finally:
                # Always release the connection, even when an assertion fails.
                await conn.close()
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_tcp_connection_framing(self):
        """A frame sent over TCP reaches the server and the reply is readable."""
        from src.plugin_runtime.transport.tcp import TCPTransportClient, TCPTransportServer

        server = TCPTransportServer()
        received = asyncio.Event()
        received_data = []

        async def handler(conn):
            data = await conn.recv_frame()
            received_data.append(data)
            await conn.send_frame(b"tcp_pong")
            received.set()

        await server.start(handler)
        try:
            # Split on the LAST colon so a host that itself contains colons
            # does not break the unpacking.
            host, _, port = server.get_address().rpartition(":")

            client = TCPTransportClient(host, int(port))
            conn = await client.connect()
            try:
                await conn.send_frame(b"tcp_ping")

                await asyncio.wait_for(received.wait(), timeout=5.0)
                assert received_data[0] == b"tcp_ping"

                resp = await asyncio.wait_for(conn.recv_frame(), timeout=5.0)
                assert resp == b"tcp_pong"
            finally:
                await conn.close()
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_transport_factory(self):
        """The factory returns a server and a client for UDS and TCP addresses."""
        from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server

        server = create_transport_server()
        assert server is not None

        # UDS-style address (filesystem path).
        client = create_transport_client("/tmp/test.sock")
        assert client is not None

        # TCP-style address (host:port).
        client = create_transport_client("127.0.0.1:9999")
        assert client is not None
|
||
|
||
|
||
# ─── Host 层测试 ──────────────────────────────────────────
|
||
|
||
class TestHost:
|
||
"""Host 端基础设施测试"""
|
||
|
||
def test_policy_engine(self):
|
||
"""策略引擎测试"""
|
||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||
|
||
engine = PolicyEngine()
|
||
|
||
# 注册插件
|
||
token = engine.register_plugin(
|
||
plugin_id="test_plugin",
|
||
generation=1,
|
||
capabilities=["send.text", "db.query"],
|
||
)
|
||
|
||
assert token.plugin_id == "test_plugin"
|
||
assert "send.text" in token.capabilities
|
||
|
||
# 能力检查
|
||
ok, _ = engine.check_capability("test_plugin", "send.text")
|
||
assert ok
|
||
|
||
ok, reason = engine.check_capability("test_plugin", "llm.generate")
|
||
assert not ok
|
||
assert "未获授权" in reason
|
||
|
||
# 未注册插件
|
||
ok, reason = engine.check_capability("unknown", "send.text")
|
||
assert not ok
|
||
|
||
ok, reason = engine.check_capability("test_plugin", "send.text", generation=2)
|
||
assert not ok
|
||
assert "generation 不匹配" in reason
|
||
|
||
def test_circuit_breaker_removed(self):
|
||
"""熔断器已移除,验证 supervisor 不依赖它"""
|
||
pass
|
||
|
||
def test_circuit_breaker_registry_removed(self):
|
||
"""熔断器注册表已移除"""
|
||
pass
|
||
|
||
|
||
# ─── SDK 测试 ─────────────────────────────────────────────
|
||
|
||
class TestSDK:
|
||
"""SDK 框架测试"""
|
||
|
||
def test_component_decorators(self):
|
||
"""组件装饰器测试"""
|
||
from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
|
||
from maibot_sdk.types import ActivationType, EventType
|
||
|
||
class TestPlugin(MaiBotPlugin):
|
||
@Action("greet", activation_type=ActivationType.KEYWORD, activation_keywords=["hi"])
|
||
async def handle_greet(self, **kwargs):
|
||
return True, "ok"
|
||
|
||
@Command("echo", pattern=r"^/echo")
|
||
async def handle_echo(self, **kwargs):
|
||
return True, "echoed", 2
|
||
|
||
@Tool("search", parameters={"query": {"type": "string"}})
|
||
async def handle_search(self, **kwargs):
|
||
return {"result": "found"}
|
||
|
||
@EventHandler("on_start", event_type=EventType.ON_START)
|
||
async def handle_start(self, **kwargs):
|
||
return True, False, "started"
|
||
|
||
plugin = TestPlugin()
|
||
components = plugin.get_components()
|
||
|
||
assert len(components) == 4
|
||
|
||
names = {c["name"] for c in components}
|
||
assert "greet" in names
|
||
assert "echo" in names
|
||
assert "search" in names
|
||
assert "on_start" in names
|
||
|
||
types = {c["type"] for c in components}
|
||
assert "action" in types
|
||
assert "command" in types
|
||
assert "tool" in types
|
||
assert "event_handler" in types
|
||
|
||
def test_plugin_context_not_initialized(self):
|
||
"""未初始化上下文时应报错"""
|
||
from maibot_sdk import MaiBotPlugin
|
||
|
||
plugin = MaiBotPlugin()
|
||
with pytest.raises(RuntimeError, match="尚未初始化"):
|
||
_ = plugin.ctx
|
||
|
||
def test_plugin_context_injection(self):
|
||
"""上下文注入测试"""
|
||
from maibot_sdk import MaiBotPlugin
|
||
from maibot_sdk.context import PluginContext
|
||
|
||
plugin = MaiBotPlugin()
|
||
ctx = PluginContext(plugin_id="test")
|
||
plugin._set_context(ctx)
|
||
|
||
assert plugin.ctx.plugin_id == "test"
|
||
assert plugin.ctx.send is not None
|
||
assert plugin.ctx.db is not None
|
||
assert plugin.ctx.llm is not None
|
||
assert plugin.ctx.config is not None
|
||
|
||
|
||
# ─── 端到端集成测试 ────────────────────────────────────────
|
||
|
||
class TestE2E:
    """End-to-end integration tests (Host <-> Runner communication)."""

    @pytest.mark.asyncio
    async def test_handshake(self):
        """A runner.hello request over UDS is validated and answered by the host side."""
        from src.plugin_runtime.protocol.codec import MsgPackCodec
        from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
        from src.plugin_runtime.transport.uds import UDSTransportClient, UDSTransportServer

        import secrets
        import tempfile

        socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
        session_token = secrets.token_hex(16)
        codec = MsgPackCodec()
        handshake_done = asyncio.Event()
        server_result = {}

        async def server_handler(conn):
            # Receive and validate the hello request.
            data = await conn.recv_frame()
            env = codec.decode_envelope(data)
            assert env.method == "runner.hello"

            hello = HelloPayload.model_validate(env.payload)
            assert hello.session_token == session_token

            # Send the acceptance response.
            resp_payload = HelloResponsePayload(
                accepted=True,
                host_version="1.0",
                assigned_generation=1,
            )
            resp = env.make_response(payload=resp_payload.model_dump())
            await conn.send_frame(codec.encode_envelope(resp))

            server_result["runner_id"] = hello.runner_id
            handshake_done.set()

            # Keep the connection open for a moment.
            await asyncio.sleep(1.0)

        server = UDSTransportServer(socket_path=socket_path)
        await server.start(server_handler)
        try:
            # Client side of the handshake.
            client = UDSTransportClient(socket_path)
            conn = await client.connect()
            try:
                hello = HelloPayload(
                    runner_id="test-runner",
                    sdk_version="1.0.0",
                    session_token=session_token,
                )
                env = Envelope(
                    request_id=1,
                    message_type=MessageType.REQUEST,
                    method="runner.hello",
                    payload=hello.model_dump(),
                )
                await conn.send_frame(codec.encode_envelope(env))

                # Bounded wait: if an assertion inside server_handler fails,
                # no response is ever sent and an unbounded recv would hang.
                resp_data = await asyncio.wait_for(conn.recv_frame(), timeout=5.0)
                resp = codec.decode_envelope(resp_data)
                resp_payload = HelloResponsePayload.model_validate(resp.payload)

                assert resp_payload.accepted
                assert resp_payload.assigned_generation == 1

                await asyncio.wait_for(handshake_done.wait(), timeout=5.0)
                assert server_result["runner_id"] == "test-runner"
            finally:
                # Release the connection even when an assertion fails.
                await conn.close()
        finally:
            await server.stop()
|
||
|
||
|
||
# ─── Manifest 校验测试 ─────────────────────────────────────
|
||
|
||
class TestManifestValidator:
|
||
"""Manifest 校验器测试"""
|
||
|
||
def test_valid_manifest(self):
|
||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||
|
||
validator = ManifestValidator()
|
||
manifest = {
|
||
"manifest_version": 1,
|
||
"name": "test_plugin",
|
||
"version": "1.0.0",
|
||
"description": "测试插件",
|
||
"author": "test",
|
||
}
|
||
assert validator.validate(manifest) is True
|
||
assert len(validator.errors) == 0
|
||
|
||
def test_missing_required_fields(self):
|
||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||
|
||
validator = ManifestValidator()
|
||
manifest = {"manifest_version": 1}
|
||
assert validator.validate(manifest) is False
|
||
assert len(validator.errors) >= 4 # name, version, description, author
|
||
|
||
def test_unsupported_manifest_version(self):
|
||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||
|
||
validator = ManifestValidator()
|
||
manifest = {
|
||
"manifest_version": 999,
|
||
"name": "test",
|
||
"version": "1.0",
|
||
"description": "d",
|
||
"author": "a",
|
||
}
|
||
assert validator.validate(manifest) is False
|
||
assert any("manifest_version" in e for e in validator.errors)
|
||
|
||
def test_host_version_compatibility(self):
|
||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||
|
||
validator = ManifestValidator(host_version="0.8.5")
|
||
manifest = {
|
||
"name": "test",
|
||
"version": "1.0",
|
||
"description": "d",
|
||
"author": "a",
|
||
"host_application": {"min_version": "0.9.0"},
|
||
}
|
||
assert validator.validate(manifest) is False
|
||
assert any("Host 版本不兼容" in e for e in validator.errors)
|
||
|
||
def test_recommended_fields_warning(self):
|
||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||
|
||
validator = ManifestValidator()
|
||
manifest = {
|
||
"name": "test",
|
||
"version": "1.0",
|
||
"description": "d",
|
||
"author": "a",
|
||
}
|
||
validator.validate(manifest)
|
||
assert len(validator.warnings) >= 3 # license, keywords, categories
|
||
|
||
|
||
class TestVersionComparator:
|
||
"""版本号比较器测试"""
|
||
|
||
def test_normalize(self):
|
||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||
|
||
assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0"
|
||
assert VersionComparator.normalize_version("1.2") == "1.2.0"
|
||
assert VersionComparator.normalize_version("") == "0.0.0"
|
||
|
||
def test_compare(self):
|
||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||
|
||
assert VersionComparator.compare("0.8.0", "0.8.0") == 0
|
||
assert VersionComparator.compare("0.8.0", "0.9.0") == -1
|
||
assert VersionComparator.compare("1.0.0", "0.9.0") == 1
|
||
|
||
def test_is_in_range(self):
|
||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||
|
||
ok, _ = VersionComparator.is_in_range("0.8.5", "0.8.0", "0.9.0")
|
||
assert ok
|
||
ok, _ = VersionComparator.is_in_range("0.7.0", "0.8.0", "0.9.0")
|
||
assert not ok
|
||
ok, _ = VersionComparator.is_in_range("1.0.0", "0.8.0", "0.9.0")
|
||
assert not ok
|
||
|
||
|
||
# ─── 依赖解析测试 ──────────────────────────────────────────
|
||
|
||
class TestDependencyResolution:
|
||
"""插件依赖解析测试"""
|
||
|
||
def test_topological_sort(self):
|
||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||
|
||
loader = PluginLoader()
|
||
candidates = {
|
||
"core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
|
||
"auth": ("dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py"),
|
||
"api": ("dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, "plugin.py"),
|
||
}
|
||
|
||
order, failed = loader._resolve_dependencies(candidates)
|
||
assert len(failed) == 0
|
||
assert order.index("core") < order.index("auth")
|
||
assert order.index("auth") < order.index("api")
|
||
|
||
def test_missing_dependency(self):
|
||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||
|
||
loader = PluginLoader()
|
||
candidates = {
|
||
"plugin_a": ("dir_a", {"name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"]}, "plugin.py"),
|
||
}
|
||
|
||
order, failed = loader._resolve_dependencies(candidates)
|
||
assert "plugin_a" in failed
|
||
assert "缺少依赖" in failed["plugin_a"]
|
||
|
||
def test_circular_dependency(self):
|
||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||
|
||
loader = PluginLoader()
|
||
candidates = {
|
||
"a": ("dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py"),
|
||
"b": ("dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py"),
|
||
}
|
||
|
||
order, failed = loader._resolve_dependencies(candidates)
|
||
assert len(failed) >= 1 # 至少一个循环插件被标记
|
||
|
||
|
||
# ─── Host-side ComponentRegistry 测试 ──────────────────────
|
||
|
||
class TestComponentRegistry:
|
||
"""Host-side 组件注册表测试"""
|
||
|
||
def test_register_and_query(self):
|
||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||
|
||
reg = ComponentRegistry()
|
||
reg.register_component("greet", "action", "plugin_a", {
|
||
"description": "打招呼",
|
||
"activation_type": "keyword",
|
||
"activation_keywords": ["hi"],
|
||
})
|
||
reg.register_component("help", "command", "plugin_a", {
|
||
"command_pattern": r"^/help",
|
||
})
|
||
reg.register_component("search", "tool", "plugin_b", {
|
||
"description": "搜索",
|
||
})
|
||
|
||
stats = reg.get_stats()
|
||
assert stats["total"] == 3
|
||
assert stats["action"] == 1
|
||
assert stats["command"] == 1
|
||
assert stats["tool"] == 1
|
||
|
||
def test_query_by_type(self):
|
||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||
|
||
reg = ComponentRegistry()
|
||
reg.register_component("a1", "action", "p1", {})
|
||
reg.register_component("a2", "action", "p2", {})
|
||
|
||
actions = reg.get_components_by_type("action")
|
||
assert len(actions) == 2
|
||
|
||
def test_find_command_by_text(self):
|
||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||
|
||
reg = ComponentRegistry()
|
||
reg.register_component("help", "command", "p1", {
|
||
"command_pattern": r"^/help",
|
||
})
|
||
reg.register_component("echo", "command", "p1", {
|
||
"command_pattern": r"^/echo\s",
|
||
})
|
||
|
||
match = reg.find_command_by_text("/help me")
|
||
assert match is not None
|
||
assert match.name == "help"
|
||
|
||
match = reg.find_command_by_text("/echo hello")
|
||
assert match is not None
|
||
assert match.name == "echo"
|
||
|
||
match = reg.find_command_by_text("no match")
|
||
assert match is None
|
||
|
||
def test_enable_disable(self):
|
||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||
|
||
reg = ComponentRegistry()
|
||
reg.register_component("a1", "action", "p1", {})
|
||
reg.set_component_enabled("p1.a1", False)
|
||
|
||
actions = reg.get_components_by_type("action", enabled_only=True)
|
||
assert len(actions) == 0
|
||
|
||
actions = reg.get_components_by_type("action", enabled_only=False)
|
||
assert len(actions) == 1
|
||
|
||
def test_remove_by_plugin(self):
|
||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||
|
||
reg = ComponentRegistry()
|
||
reg.register_component("a1", "action", "p1", {})
|
||
reg.register_component("c1", "command", "p1", {})
|
||
reg.register_component("a2", "action", "p2", {})
|
||
|
||
removed = reg.remove_components_by_plugin("p1")
|
||
assert removed == 2
|
||
assert reg.get_stats()["total"] == 1
|
||
|
||
def test_event_handlers_sorted_by_weight(self):
|
||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||
|
||
reg = ComponentRegistry()
|
||
reg.register_component("h_low", "event_handler", "p1", {
|
||
"event_type": "on_message", "weight": 10,
|
||
})
|
||
reg.register_component("h_high", "event_handler", "p2", {
|
||
"event_type": "on_message", "weight": 100,
|
||
})
|
||
|
||
handlers = reg.get_event_handlers("on_message")
|
||
assert handlers[0].name == "h_high"
|
||
assert handlers[1].name == "h_low"
|
||
|
||
def test_tools_for_llm(self):
|
||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||
|
||
reg = ComponentRegistry()
|
||
reg.register_component("search", "tool", "p1", {
|
||
"description": "搜索工具",
|
||
"parameters_raw": {"query": {"type": "string"}},
|
||
})
|
||
|
||
tools = reg.get_tools_for_llm()
|
||
assert len(tools) == 1
|
||
assert tools[0]["name"] == "p1.search"
|
||
assert tools[0]["parameters"]["query"]["type"] == "string"
|
||
|
||
|
||
# ─── EventDispatcher 测试 ─────────────────────────────────
|
||
|
||
class TestEventDispatcher:
    """Tests for the Host-side event dispatcher."""

    @pytest.mark.asyncio
    async def test_dispatch_non_blocking(self):
        """A non-intercepting handler is invoked and processing continues."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.event_dispatcher import EventDispatcher

        reg = ComponentRegistry()
        reg.register_component(
            "h1",
            "event_handler",
            "p1",
            {
                "event_type": "on_start",
                "weight": 0,
                "intercept_message": False,
            },
        )

        dispatcher = EventDispatcher(reg)
        call_log = []
        invoked = asyncio.Event()

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append((plugin_id, comp_name))
            invoked.set()
            return {"success": True, "continue_processing": True}

        should_continue, _ = await dispatcher.dispatch_event("on_start", mock_invoke)
        assert should_continue

        # The handler may run in a background task after dispatch_event
        # returns. Wait for the actual invocation instead of sleeping for a
        # fixed time, so the test neither races nor wastes time.
        await asyncio.wait_for(invoked.wait(), timeout=1.0)
        assert len(call_log) == 1
        assert call_log[0] == ("p1", "h1")

    @pytest.mark.asyncio
    async def test_dispatch_intercepting(self):
        """An intercepting handler can stop processing and replace the message."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.event_dispatcher import EventDispatcher

        reg = ComponentRegistry()
        reg.register_component(
            "filter",
            "event_handler",
            "p1",
            {
                "event_type": "on_message_pre_process",
                "weight": 100,
                "intercept_message": True,
            },
        )

        dispatcher = EventDispatcher(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            return {
                "success": True,
                "continue_processing": False,
                "modified_message": {"plain_text": "filtered"},
            }

        should_continue, modified = await dispatcher.dispatch_event(
            "on_message_pre_process",
            mock_invoke,
            message={"plain_text": "hello"},
        )
        assert not should_continue
        assert modified is not None
        assert modified["plain_text"] == "filtered"
|
||
|
||
|
||
class TestEventBus:
    """Tests for the core event bus and its IPC bridge."""

    @pytest.mark.asyncio
    async def test_bridge_preserves_modified_message(self, monkeypatch):
        """A message modified by the IPC bridge is returned; the original stays untouched."""
        import types

        # Install a stub for the message data-model module BEFORE the imports
        # below, presumably so they resolve these names without loading the
        # real module. monkeypatch restores sys.modules after the test.
        fake_message_data_model = types.ModuleType("src.common.data_models.message_data_model")
        fake_message_data_model.ReplyContentType = object
        fake_message_data_model.ReplyContent = object
        fake_message_data_model.ForwardNode = object
        fake_message_data_model.ReplySetModel = object
        monkeypatch.setitem(sys.modules, "src.common.data_models.message_data_model", fake_message_data_model)

        from src.core.event_bus import EventBus
        from src.core.types import EventType, MaiMessages
        from src.plugin_runtime import integration as integration_module

        bus = EventBus()

        async def noop_handler(message):
            return True, message

        bus.subscribe(EventType.ON_MESSAGE, noop_handler, name="noop", intercept=True)

        class FakeManager:
            """Stand-in runtime manager whose bridge always rewrites the message."""

            is_running = True

            async def bridge_event(self, event_type_value, message_dict=None, extra_args=None):
                assert event_type_value == EventType.ON_MESSAGE.value
                return True, {"plain_text": "modified by ipc"}

        monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

        original = MaiMessages(plain_text="original")
        continue_flag, modified = await bus.emit(EventType.ON_MESSAGE, original)

        assert continue_flag is True
        assert modified is not None
        assert modified.plain_text == "modified by ipc"
        # The object passed to emit() must not have been mutated.
        assert original.plain_text == "original"
|
||
|
||
|
||
# ─── MaiMessages 测试 ─────────────────────────────────────
|
||
|
||
class TestMaiMessages:
|
||
"""统一消息模型测试"""
|
||
|
||
def test_create_and_serialize(self):
|
||
from maibot_sdk.messages import MaiMessages, MessageSegment
|
||
|
||
msg = MaiMessages(
|
||
message_segments=[MessageSegment(type="text", data={"text": "hello"})],
|
||
plain_text="hello",
|
||
stream_id="stream_1",
|
||
)
|
||
|
||
d = msg.to_rpc_dict()
|
||
assert d["plain_text"] == "hello"
|
||
assert len(d["message_segments"]) == 1
|
||
|
||
msg2 = MaiMessages.from_rpc_dict(d)
|
||
assert msg2.plain_text == "hello"
|
||
|
||
def test_deepcopy(self):
|
||
from maibot_sdk.messages import MaiMessages
|
||
|
||
msg = MaiMessages(plain_text="original")
|
||
msg2 = msg.deepcopy()
|
||
msg2.plain_text = "modified"
|
||
assert msg.plain_text == "original"
|
||
|
||
def test_modify_flags(self):
|
||
from maibot_sdk.messages import MaiMessages
|
||
from maibot_sdk.types import ModifyFlag
|
||
|
||
msg = MaiMessages(plain_text="hello")
|
||
assert msg.can_modify(ModifyFlag.CAN_MODIFY_PROMPT)
|
||
|
||
msg.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False)
|
||
assert not msg.modify_prompt("new prompt")
|
||
assert msg.llm_prompt is None
|
||
|
||
assert msg.modify_response("new response")
|
||
assert msg.llm_response_content == "new response"
|
||
|
||
|
||
# ─── WorkflowExecutor 测试 ────────────────────────────────
|
||
|
||
class TestWorkflowExecutor:
|
||
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
|
||
|
||
@pytest.mark.asyncio
async def test_empty_pipeline_completes(self):
    """With no workflow_step registered, every stage is skipped and the run completes."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    reg = ComponentRegistry()
    executor = WorkflowExecutor(reg)

    async def mock_invoke(plugin_id, comp_name, args):
        return {"hook_result": "continue"}

    result, _, ctx = await executor.execute(
        mock_invoke,
        message={"plain_text": "test"},
    )
    assert result.status == "completed"
    assert result.return_message == "workflow completed"
    # One timing entry per pipeline stage (6 stages).
    assert len(ctx.timings) == 6
|
||
|
||
@pytest.mark.asyncio
async def test_blocking_hook_modifies_message(self):
    """A blocking hook can modify the message, and the change is logged."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    reg = ComponentRegistry()
    reg.register_component(
        "upper",
        "workflow_step",
        "p1",
        {
            "stage": "pre_process",
            "priority": 10,
            "blocking": True,
        },
    )
    executor = WorkflowExecutor(reg)

    async def mock_invoke(plugin_id, comp_name, args):
        # Return a copy of the incoming message with plain_text upper-cased.
        msg = args.get("message", {})
        return {
            "hook_result": "continue",
            "modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
        }

    result, final_msg, ctx = await executor.execute(
        mock_invoke,
        message={"plain_text": "hello"},
    )
    assert result.status == "completed"
    assert final_msg["plain_text"] == "HELLO"
    assert len(ctx.modification_log) == 1
    assert ctx.modification_log[0].stage == "pre_process"
|
||
|
||
@pytest.mark.asyncio
async def test_abort_stops_pipeline(self):
    """A hook returning "abort" terminates the pipeline at its stage."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    reg = ComponentRegistry()
    reg.register_component(
        "blocker",
        "workflow_step",
        "p1",
        {
            "stage": "pre_process",
            "priority": 10,
            "blocking": True,
        },
    )
    executor = WorkflowExecutor(reg)

    async def mock_invoke(plugin_id, comp_name, args):
        return {"hook_result": "abort"}

    result, _, _ = await executor.execute(
        mock_invoke,
        message={"plain_text": "test"},
    )
    assert result.status == "aborted"
    assert result.stopped_at == "pre_process"
|
||
|
||
@pytest.mark.asyncio
async def test_skip_stage(self):
    """A hook returning "skip_stage" skips the remaining hooks of that stage."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    reg = ComponentRegistry()
    # The high-priority hook returns skip_stage.
    reg.register_component(
        "skipper",
        "workflow_step",
        "p1",
        {
            "stage": "ingress",
            "priority": 100,
            "blocking": True,
        },
    )
    # The low-priority hook must not be executed.
    reg.register_component(
        "checker",
        "workflow_step",
        "p2",
        {
            "stage": "ingress",
            "priority": 1,
            "blocking": True,
        },
    )
    executor = WorkflowExecutor(reg)

    call_log = []

    async def mock_invoke(plugin_id, comp_name, args):
        call_log.append(comp_name)
        if comp_name == "skipper":
            return {"hook_result": "skip_stage"}
        return {"hook_result": "continue"}

    result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
    assert result.status == "completed"
    # Only skipper was called; checker was skipped.
    assert call_log == ["skipper"]
|
||
|
||
@pytest.mark.asyncio
async def test_pre_filter(self):
    """A hook is skipped when the message does not match its filter."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    reg = ComponentRegistry()
    reg.register_component(
        "only_dm",
        "workflow_step",
        "p1",
        {
            "stage": "ingress",
            "priority": 10,
            "blocking": True,
            "filter": {"chat_type": "direct"},
        },
    )
    executor = WorkflowExecutor(reg)

    call_log = []

    async def mock_invoke(plugin_id, comp_name, args):
        call_log.append(comp_name)
        return {"hook_result": "continue"}

    # Filter does not match: the hook must not be called.
    await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
    assert not call_log

    # Filter matches: the hook must be called.
    await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
    assert call_log == ["only_dm"]
|
||
|
||
@pytest.mark.asyncio
async def test_error_policy_skip(self):
    """With error_policy="skip" a failing hook is skipped and execution continues."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    reg = ComponentRegistry()
    reg.register_component(
        "failer",
        "workflow_step",
        "p1",
        {
            "stage": "ingress",
            "priority": 100,
            "blocking": True,
            "error_policy": "skip",
        },
    )
    reg.register_component(
        "ok_step",
        "workflow_step",
        "p2",
        {
            "stage": "ingress",
            "priority": 1,
            "blocking": True,
        },
    )
    executor = WorkflowExecutor(reg)

    call_log = []

    async def mock_invoke(plugin_id, comp_name, args):
        call_log.append(comp_name)
        if comp_name == "failer":
            raise RuntimeError("boom")
        return {"hook_result": "continue"}

    result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
    assert result.status == "completed"
    assert "failer" in call_log
    assert "ok_step" in call_log
    # The failure is still recorded in the context.
    assert any("boom" in e for e in ctx.errors)
|
||
|
||
@pytest.mark.asyncio
async def test_error_policy_abort(self):
    """With the default error_policy ("abort") a failing hook fails the pipeline."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    reg = ComponentRegistry()
    reg.register_component(
        "failer",
        "workflow_step",
        "p1",
        {
            "stage": "ingress",
            "priority": 10,
            "blocking": True,
            # error_policy defaults to "abort"
        },
    )
    executor = WorkflowExecutor(reg)

    async def mock_invoke(plugin_id, comp_name, args):
        raise RuntimeError("fatal")

    result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
    assert result.status == "failed"
    assert result.stopped_at == "ingress"
|
||
|
||
@pytest.mark.asyncio
async def test_nonblocking_hooks_concurrent(self):
    """Non-blocking hooks run concurrently and do not modify the message."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    registry = ComponentRegistry()
    # Three non-blocking steps in the same stage, each from its own plugin.
    for idx in range(3):
        hook_meta = {
            "stage": "post_process",
            "priority": 0,
            "blocking": False,
        }
        registry.register_component(f"nb_{idx}", "workflow_step", f"p{idx}", hook_meta)
    workflow = WorkflowExecutor(registry)

    call_log = []

    async def mock_invoke(plugin_id, comp_name, args):
        call_log.append(comp_name)
        # Every hook tries to rewrite the message.
        reply = {"hook_result": "continue"}
        reply["modified_message"] = {"plain_text": "ignored"}
        return reply

    outcome, final_message, _ = await workflow.execute(
        mock_invoke,
        message={"plain_text": "original"},
    )

    # modified_message returned by non-blocking hooks is ignored.
    assert final_message["plain_text"] == "original"
    # Give the asynchronous tasks time to finish.
    await asyncio.sleep(0.1)
    assert outcome.status == "completed"
|
||
|
||
@pytest.mark.asyncio
async def test_command_routing(self):
    """Built-in command routing in the PLAN stage."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    registry = ComponentRegistry()
    # A single command component, matched by regex.
    registry.register_component(
        "help",
        "command",
        "p1",
        {"command_pattern": r"^/help"},
    )
    workflow = WorkflowExecutor(registry)

    async def mock_invoke(plugin_id, comp_name, args):
        # The command returns its own output; anything else just continues.
        is_help = comp_name == "help"
        return {"output": "帮助信息"} if is_help else {"hook_result": "continue"}

    outcome, _, context = await workflow.execute(
        mock_invoke,
        message={"plain_text": "/help topic"},
    )

    assert outcome.status == "completed"
    assert context.matched_command == "p1.help"
    # The command's return value is stored under the "plan" stage outputs.
    command_output = context.get_stage_output("plan", "command_result")
    assert command_output is not None
    assert command_output["output"] == "帮助信息"
|
||
|
||
@pytest.mark.asyncio
async def test_stage_outputs(self):
    """stage_outputs data is passed along from one stage to the next."""
    from src.plugin_runtime.host.component_registry import ComponentRegistry
    from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

    registry = ComponentRegistry()
    # The ingress-stage step writes data.
    writer_meta = {
        "stage": "ingress",
        "priority": 10,
        "blocking": True,
    }
    registry.register_component("writer", "workflow_step", "p1", writer_meta)
    # The pre_process-stage step reads it back.
    reader_meta = {
        "stage": "pre_process",
        "priority": 10,
        "blocking": True,
    }
    registry.register_component("reader", "workflow_step", "p2", reader_meta)
    workflow = WorkflowExecutor(registry)

    async def mock_invoke(plugin_id, comp_name, args):
        reply = {"hook_result": "continue"}
        if comp_name == "writer":
            reply["stage_output"] = {"parsed_intent": "greeting"}
        elif comp_name == "reader":
            # Verify that stage_outputs was handed to this step.
            upstream = args.get("stage_outputs", {}).get("ingress", {})
            assert upstream.get("parsed_intent") == "greeting"
        return reply

    outcome, _, context = await workflow.execute(
        mock_invoke,
        message={"plain_text": "hi"},
    )

    assert outcome.status == "completed"
    assert context.get_stage_output("ingress", "parsed_intent") == "greeting"
|
||
|
||
|
||
class TestRPCServer:
    """Generation-guard tests for the RPC server."""

    def test_ignore_stale_generation_response(self):
        """A response tagged with an older generation leaves the pending request untouched."""
        from src.plugin_runtime.host.rpc_server import RPCServer
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        class DummyTransport:
            """Transport stub whose methods do nothing."""

            async def start(self, handler):
                return None

            async def stop(self):
                return None

            def get_address(self):
                return "dummy"

        rpc = RPCServer(transport=DummyTransport())
        # The server is on generation 2; the response below claims generation 1.
        rpc._runner_generation = 2

        event_loop = asyncio.new_event_loop()
        try:
            pending = event_loop.create_future()
            rpc._pending_requests[1] = pending

            rpc._handle_response(
                Envelope(
                    request_id=1,
                    message_type=MessageType.RESPONSE,
                    method="plugin.health",
                    generation=1,
                    payload={"healthy": True},
                )
            )

            # The future is neither resolved nor removed from the pending table.
            assert not pending.done()
            assert 1 in rpc._pending_requests
        finally:
            event_loop.close()
|
||
|
||
|
||
class TestSupervisor:
    """Lifecycle edge-case tests for the plugin supervisor."""

    @staticmethod
    def _build_register_payload(plugin_id: str = "plugin_a"):
        """Build a RegisterComponentsPayload declaring one "handler" event handler.

        Args:
            plugin_id: Plugin id used for both the payload and its component.
        """
        from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload

        return RegisterComponentsPayload(
            plugin_id=plugin_id,
            plugin_version="1.0.0",
            components=[
                ComponentDeclaration(
                    name="handler",
                    component_type="event_handler",
                    plugin_id=plugin_id,
                    metadata={"event_type": "on_message"},
                )
            ],
            capabilities_required=["send.text"],
        )

    @staticmethod
    def _make_process(pid: int):
        """Return a fake process object that records terminate()/kill() calls.

        Args:
            pid: Value exposed as the fake process's ``pid`` attribute.
        """

        class FakeProcess:
            def __init__(self):
                self.pid = pid
                self.returncode = None
                self.stdout = None
                self.stderr = None
                self.terminated = False
                self.killed = False

            def terminate(self):
                self.terminated = True
                self.returncode = 0

            def kill(self):
                self.killed = True
                self.returncode = -9

            async def wait(self):
                return self.returncode

        return FakeProcess()

    @pytest.mark.asyncio
    async def test_reload_waits_for_target_generation(self, monkeypatch):
        """reload_plugins only sends its request once the runner generation is 2.

        The fake ``send_request`` asserts the generation itself; the test then
        checks that the new process is installed and the old one terminated.
        """
        from src.plugin_runtime.host.supervisor import PluginSupervisor
        from src.plugin_runtime.protocol.envelope import HealthPayload

        supervisor = PluginSupervisor(plugin_dirs=[])
        old_process = self._make_process(1)
        new_process = self._make_process(2)

        class FakeRPCServer:
            def __init__(self):
                self.runner_generation = 1
                self.is_connected = True
                self.session_token = "fake-token"

            def reset_session_token(self):
                self.session_token = "new-fake-token"
                return self.session_token

            def restore_session_token(self, token):
                self.session_token = token

            async def send_request(self, method, timeout_ms=5000, **kwargs):
                # Fails the test if a request is sent before the generation advanced.
                assert self.runner_generation == 2
                return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())

        supervisor._rpc_server = FakeRPCServer()
        supervisor._runner_process = old_process

        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        background_tasks = set()

        async def fake_spawn_runner():
            supervisor._runner_process = new_process

            async def advance_generation():
                # The generation advances shortly after the spawn returns.
                await asyncio.sleep(0.01)
                supervisor._rpc_server.runner_generation = 2

            task = asyncio.create_task(advance_generation())
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)

        monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)

        await supervisor.reload_plugins("test")

        assert supervisor._runner_process is new_process
        assert old_process.terminated is True

    @pytest.mark.asyncio
    async def test_reload_restores_runtime_state_on_failure(self, monkeypatch):
        """A failed reload keeps the old process, registration and component.

        The fake ``send_request`` always raises, after the fake spawn has
        already swapped in the new process and bumped the generation.
        """
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[])
        old_process = self._make_process(1)
        new_process = self._make_process(2)
        old_reg = self._build_register_payload()

        supervisor._runner_process = old_process
        supervisor._registered_plugins[old_reg.plugin_id] = old_reg
        supervisor._rebuild_runtime_state()

        class FakeRPCServer:
            def __init__(self):
                self.runner_generation = 1
                self.is_connected = True
                self.session_token = "fake-token"

            def reset_session_token(self):
                self.session_token = "new-fake-token"
                return self.session_token

            def restore_session_token(self, token):
                self.session_token = token

            async def send_request(self, method, timeout_ms=5000, **kwargs):
                raise RuntimeError("new runner unhealthy")

        supervisor._rpc_server = FakeRPCServer()

        async def fake_spawn_runner():
            supervisor._runner_process = new_process
            supervisor._rpc_server.runner_generation = 2

        monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)

        await supervisor.reload_plugins("test")

        assert supervisor._runner_process is old_process
        assert old_reg.plugin_id in supervisor._registered_plugins
        assert supervisor.component_registry.get_component("plugin_a.handler") is not None

    @pytest.mark.asyncio
    async def test_attach_stderr_drain_drains_stream(self):
        """_attach_stderr_drain creates a drain task for stderr; it finishes once the stream is read."""
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[])

        stderr = asyncio.StreamReader()
        stderr.feed_data(b"fatal startup error\n")
        stderr.feed_eof()

        # stdout=None mimics the new architecture (stdout is no longer captured).
        process = SimpleNamespace(pid=99, stdout=None, stderr=stderr)
        supervisor._attach_stderr_drain(process)

        # Give the drain task enough time to consume the data.
        await asyncio.sleep(0.05)

        assert supervisor._stderr_drain_task is None or supervisor._stderr_drain_task.done()
|
||
|
||
|
||
class TestIntegration:
    """Startup/cleanup tests for the runtime integration layer."""

    @pytest.mark.asyncio
    async def test_start_cleans_up_started_supervisors_on_failure(self, monkeypatch):
        """If the second supervisor fails to start, the first one is stopped."""
        from src.plugin_runtime import integration as integration_module

        instances = []

        class FakeCapabilityService:
            """Accepts capability registrations and discards them."""

            def register_capability(self, name, impl):
                return None

        class FakeSupervisor:
            """Supervisor stand-in that records itself in ``instances`` on creation."""

            def __init__(self, plugin_dirs=None, socket_path=None):
                self.plugin_dirs = plugin_dirs or []
                self.capability_service = FakeCapabilityService()
                self.stopped = False
                instances.append(self)

            async def start(self):
                # Only the second supervisor created fails to start.
                is_second = len(instances) == 2 and self is instances[1]
                if is_second:
                    raise RuntimeError("boom")

            async def stop(self):
                self.stopped = True

        manager_cls = integration_module.PluginRuntimeManager
        monkeypatch.setattr(
            manager_cls,
            "_get_builtin_plugin_dirs",
            staticmethod(lambda: ["builtin"]),
        )
        monkeypatch.setattr(
            manager_cls,
            "_get_thirdparty_plugin_dirs",
            staticmethod(lambda: ["thirdparty"]),
        )

        import src.plugin_runtime.host.supervisor as supervisor_module

        # Swap the real supervisor class for the fake one.
        monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor)

        manager = manager_cls()
        await manager.start()

        assert manager.is_running is False
        assert len(instances) == 2
        assert instances[0].stopped is True
|