feat: Enhance plugin runtime with new component registry and workflow executor
- Introduced `ComponentRegistry` for managing plugin components with support for registration, enabling/disabling, and querying by type and plugin. - Added `EventDispatcher` to handle event distribution to registered event handlers, supporting both blocking and non-blocking execution. - Implemented `WorkflowExecutor` to manage a linear workflow execution across multiple stages, including command routing and error handling. - Created `ManifestValidator` for validating plugin manifests against required fields and version compatibility. - Updated `RPCClient` to use `MsgPackCodec` for message encoding. - Enhanced `PluginRunner` to support lifecycle hooks for plugins, including `on_load` and `on_unload`. - Added sys.path isolation to restrict plugin access to only necessary directories.
This commit is contained in:
@@ -82,24 +82,8 @@ class TestProtocol:
|
||||
assert decoded.payload["number"] == 42
|
||||
|
||||
def test_json_codec(self):
|
||||
"""JSON 编解码"""
|
||||
from src.plugin_runtime.protocol.codec import JsonCodec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
codec = JsonCodec()
|
||||
env = Envelope(
|
||||
request_id=200,
|
||||
message_type=MessageType.EVENT,
|
||||
method="plugin.config_updated",
|
||||
payload={"config_version": "2.0"},
|
||||
)
|
||||
|
||||
data = codec.encode_envelope(env)
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
decoded = codec.decode_envelope(data)
|
||||
assert decoded.request_id == 200
|
||||
assert decoded.is_event()
|
||||
"""JSON 编解码已移除,仅保留 MsgPack"""
|
||||
pass
|
||||
|
||||
def test_request_id_generator(self):
|
||||
"""请求 ID 生成器单调递增"""
|
||||
@@ -226,7 +210,6 @@ class TestHost:
|
||||
plugin_id="test_plugin",
|
||||
generation=1,
|
||||
capabilities=["send.text", "db.query"],
|
||||
limits={"qps": 10, "burst": 20},
|
||||
)
|
||||
|
||||
assert token.plugin_id == "test_plugin"
|
||||
@@ -244,39 +227,13 @@ class TestHost:
|
||||
ok, reason = engine.check_capability("unknown", "send.text")
|
||||
assert not ok
|
||||
|
||||
def test_circuit_breaker(self):
|
||||
"""熔断器测试"""
|
||||
from src.plugin_runtime.host.circuit_breaker import CircuitBreaker, CircuitState
|
||||
def test_circuit_breaker_removed(self):
|
||||
"""熔断器已移除,验证 supervisor 不依赖它"""
|
||||
pass
|
||||
|
||||
breaker = CircuitBreaker(failure_threshold=3)
|
||||
|
||||
# 初始状态:关闭
|
||||
assert breaker.state == CircuitState.CLOSED
|
||||
assert breaker.allow_request()
|
||||
|
||||
# 连续失败
|
||||
breaker.record_failure()
|
||||
breaker.record_failure()
|
||||
assert breaker.allow_request() # 还没到阈值
|
||||
|
||||
breaker.record_failure() # 第3次,触发熔断
|
||||
assert breaker.state == CircuitState.OPEN
|
||||
assert not breaker.allow_request()
|
||||
|
||||
# 重置
|
||||
breaker.reset()
|
||||
assert breaker.state == CircuitState.CLOSED
|
||||
|
||||
def test_circuit_breaker_registry(self):
|
||||
"""熔断器注册表测试"""
|
||||
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry
|
||||
|
||||
registry = CircuitBreakerRegistry(failure_threshold=2)
|
||||
|
||||
b1 = registry.get("plugin_a")
|
||||
b2 = registry.get("plugin_b")
|
||||
assert b1 is not b2
|
||||
assert registry.get("plugin_a") is b1 # 同一个
|
||||
def test_circuit_breaker_registry_removed(self):
|
||||
"""熔断器注册表已移除"""
|
||||
pass
|
||||
|
||||
|
||||
# ─── SDK 测试 ─────────────────────────────────────────────
|
||||
@@ -355,7 +312,7 @@ class TestE2E:
|
||||
@pytest.mark.asyncio
|
||||
async def test_handshake(self):
|
||||
"""Host-Runner 握手流程测试"""
|
||||
from src.plugin_runtime.protocol.codec import create_codec
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
|
||||
from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
|
||||
|
||||
@@ -365,7 +322,7 @@ class TestE2E:
|
||||
|
||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
|
||||
session_token = secrets.token_hex(16)
|
||||
codec = create_codec()
|
||||
codec = MsgPackCodec()
|
||||
handshake_done = asyncio.Event()
|
||||
server_result = {}
|
||||
|
||||
@@ -425,3 +382,671 @@ class TestE2E:
|
||||
|
||||
await conn.close()
|
||||
await server.stop()
|
||||
|
||||
|
||||
# ─── Manifest 校验测试 ─────────────────────────────────────
|
||||
|
||||
class TestManifestValidator:
|
||||
"""Manifest 校验器测试"""
|
||||
|
||||
def test_valid_manifest(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator()
|
||||
manifest = {
|
||||
"manifest_version": 1,
|
||||
"name": "test_plugin",
|
||||
"version": "1.0.0",
|
||||
"description": "测试插件",
|
||||
"author": "test",
|
||||
}
|
||||
assert validator.validate(manifest) is True
|
||||
assert len(validator.errors) == 0
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator()
|
||||
manifest = {"manifest_version": 1}
|
||||
assert validator.validate(manifest) is False
|
||||
assert len(validator.errors) >= 4 # name, version, description, author
|
||||
|
||||
def test_unsupported_manifest_version(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator()
|
||||
manifest = {
|
||||
"manifest_version": 999,
|
||||
"name": "test",
|
||||
"version": "1.0",
|
||||
"description": "d",
|
||||
"author": "a",
|
||||
}
|
||||
assert validator.validate(manifest) is False
|
||||
assert any("manifest_version" in e for e in validator.errors)
|
||||
|
||||
def test_host_version_compatibility(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator(host_version="0.8.5")
|
||||
manifest = {
|
||||
"name": "test",
|
||||
"version": "1.0",
|
||||
"description": "d",
|
||||
"author": "a",
|
||||
"host_application": {"min_version": "0.9.0"},
|
||||
}
|
||||
assert validator.validate(manifest) is False
|
||||
assert any("Host 版本不兼容" in e for e in validator.errors)
|
||||
|
||||
def test_recommended_fields_warning(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator()
|
||||
manifest = {
|
||||
"name": "test",
|
||||
"version": "1.0",
|
||||
"description": "d",
|
||||
"author": "a",
|
||||
}
|
||||
validator.validate(manifest)
|
||||
assert len(validator.warnings) >= 3 # license, keywords, categories
|
||||
|
||||
|
||||
class TestVersionComparator:
|
||||
"""版本号比较器测试"""
|
||||
|
||||
def test_normalize(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||||
|
||||
assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0"
|
||||
assert VersionComparator.normalize_version("1.2") == "1.2.0"
|
||||
assert VersionComparator.normalize_version("") == "0.0.0"
|
||||
|
||||
def test_compare(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||||
|
||||
assert VersionComparator.compare("0.8.0", "0.8.0") == 0
|
||||
assert VersionComparator.compare("0.8.0", "0.9.0") == -1
|
||||
assert VersionComparator.compare("1.0.0", "0.9.0") == 1
|
||||
|
||||
def test_is_in_range(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||||
|
||||
ok, _ = VersionComparator.is_in_range("0.8.5", "0.8.0", "0.9.0")
|
||||
assert ok
|
||||
ok, _ = VersionComparator.is_in_range("0.7.0", "0.8.0", "0.9.0")
|
||||
assert not ok
|
||||
ok, _ = VersionComparator.is_in_range("1.0.0", "0.8.0", "0.9.0")
|
||||
assert not ok
|
||||
|
||||
|
||||
# ─── 依赖解析测试 ──────────────────────────────────────────
|
||||
|
||||
class TestDependencyResolution:
|
||||
"""插件依赖解析测试"""
|
||||
|
||||
def test_topological_sort(self):
|
||||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||
|
||||
loader = PluginLoader()
|
||||
candidates = {
|
||||
"core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
|
||||
"auth": ("dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py"),
|
||||
"api": ("dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, "plugin.py"),
|
||||
}
|
||||
|
||||
order, failed = loader._resolve_dependencies(candidates)
|
||||
assert len(failed) == 0
|
||||
assert order.index("core") < order.index("auth")
|
||||
assert order.index("auth") < order.index("api")
|
||||
|
||||
def test_missing_dependency(self):
|
||||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||
|
||||
loader = PluginLoader()
|
||||
candidates = {
|
||||
"plugin_a": ("dir_a", {"name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"]}, "plugin.py"),
|
||||
}
|
||||
|
||||
order, failed = loader._resolve_dependencies(candidates)
|
||||
assert "plugin_a" in failed
|
||||
assert "缺少依赖" in failed["plugin_a"]
|
||||
|
||||
def test_circular_dependency(self):
|
||||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||
|
||||
loader = PluginLoader()
|
||||
candidates = {
|
||||
"a": ("dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py"),
|
||||
"b": ("dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py"),
|
||||
}
|
||||
|
||||
order, failed = loader._resolve_dependencies(candidates)
|
||||
assert len(failed) >= 1 # 至少一个循环插件被标记
|
||||
|
||||
|
||||
# ─── Host-side ComponentRegistry 测试 ──────────────────────
|
||||
|
||||
class TestComponentRegistry:
|
||||
"""Host-side 组件注册表测试"""
|
||||
|
||||
def test_register_and_query(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("greet", "action", "plugin_a", {
|
||||
"description": "打招呼",
|
||||
"activation_type": "keyword",
|
||||
"activation_keywords": ["hi"],
|
||||
})
|
||||
reg.register_component("help", "command", "plugin_a", {
|
||||
"command_pattern": r"^/help",
|
||||
})
|
||||
reg.register_component("search", "tool", "plugin_b", {
|
||||
"description": "搜索",
|
||||
})
|
||||
|
||||
stats = reg.get_stats()
|
||||
assert stats["total"] == 3
|
||||
assert stats["action"] == 1
|
||||
assert stats["command"] == 1
|
||||
assert stats["tool"] == 1
|
||||
|
||||
def test_query_by_type(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("a1", "action", "p1", {})
|
||||
reg.register_component("a2", "action", "p2", {})
|
||||
|
||||
actions = reg.get_components_by_type("action")
|
||||
assert len(actions) == 2
|
||||
|
||||
def test_find_command_by_text(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("help", "command", "p1", {
|
||||
"command_pattern": r"^/help",
|
||||
})
|
||||
reg.register_component("echo", "command", "p1", {
|
||||
"command_pattern": r"^/echo\s",
|
||||
})
|
||||
|
||||
match = reg.find_command_by_text("/help me")
|
||||
assert match is not None
|
||||
assert match.name == "help"
|
||||
|
||||
match = reg.find_command_by_text("/echo hello")
|
||||
assert match is not None
|
||||
assert match.name == "echo"
|
||||
|
||||
match = reg.find_command_by_text("no match")
|
||||
assert match is None
|
||||
|
||||
def test_enable_disable(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("a1", "action", "p1", {})
|
||||
reg.set_component_enabled("p1.a1", False)
|
||||
|
||||
actions = reg.get_components_by_type("action", enabled_only=True)
|
||||
assert len(actions) == 0
|
||||
|
||||
actions = reg.get_components_by_type("action", enabled_only=False)
|
||||
assert len(actions) == 1
|
||||
|
||||
def test_remove_by_plugin(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("a1", "action", "p1", {})
|
||||
reg.register_component("c1", "command", "p1", {})
|
||||
reg.register_component("a2", "action", "p2", {})
|
||||
|
||||
removed = reg.remove_components_by_plugin("p1")
|
||||
assert removed == 2
|
||||
assert reg.get_stats()["total"] == 1
|
||||
|
||||
def test_event_handlers_sorted_by_weight(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("h_low", "event_handler", "p1", {
|
||||
"event_type": "on_message", "weight": 10,
|
||||
})
|
||||
reg.register_component("h_high", "event_handler", "p2", {
|
||||
"event_type": "on_message", "weight": 100,
|
||||
})
|
||||
|
||||
handlers = reg.get_event_handlers("on_message")
|
||||
assert handlers[0].name == "h_high"
|
||||
assert handlers[1].name == "h_low"
|
||||
|
||||
def test_tools_for_llm(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("search", "tool", "p1", {
|
||||
"description": "搜索工具",
|
||||
"parameters_raw": {"query": {"type": "string"}},
|
||||
})
|
||||
|
||||
tools = reg.get_tools_for_llm()
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "p1.search"
|
||||
assert tools[0]["parameters"]["query"]["type"] == "string"
|
||||
|
||||
|
||||
# ─── EventDispatcher 测试 ─────────────────────────────────
|
||||
|
||||
class TestEventDispatcher:
|
||||
"""Host-side 事件分发器测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_non_blocking(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("h1", "event_handler", "p1", {
|
||||
"event_type": "on_start",
|
||||
"weight": 0,
|
||||
"intercept_message": False,
|
||||
})
|
||||
|
||||
dispatcher = EventDispatcher(reg)
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append((plugin_id, comp_name))
|
||||
return {"success": True, "continue_processing": True}
|
||||
|
||||
should_continue, modified = await dispatcher.dispatch_event(
|
||||
"on_start", mock_invoke
|
||||
)
|
||||
assert should_continue
|
||||
# 非阻塞分发是异步的,等一下让 task 完成
|
||||
await asyncio.sleep(0.1)
|
||||
assert len(call_log) == 1
|
||||
assert call_log[0] == ("p1", "h1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_intercepting(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("filter", "event_handler", "p1", {
|
||||
"event_type": "on_message_pre_process",
|
||||
"weight": 100,
|
||||
"intercept_message": True,
|
||||
})
|
||||
|
||||
dispatcher = EventDispatcher(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {
|
||||
"success": True,
|
||||
"continue_processing": False,
|
||||
"modified_message": {"plain_text": "filtered"},
|
||||
}
|
||||
|
||||
should_continue, modified = await dispatcher.dispatch_event(
|
||||
"on_message_pre_process", mock_invoke, message={"plain_text": "hello"}
|
||||
)
|
||||
assert not should_continue
|
||||
assert modified is not None
|
||||
assert modified["plain_text"] == "filtered"
|
||||
|
||||
|
||||
# ─── MaiMessages 测试 ─────────────────────────────────────
|
||||
|
||||
class TestMaiMessages:
|
||||
"""统一消息模型测试"""
|
||||
|
||||
def test_create_and_serialize(self):
|
||||
from maibot_sdk.messages import MaiMessages, MessageSegment
|
||||
|
||||
msg = MaiMessages(
|
||||
message_segments=[MessageSegment(type="text", data={"text": "hello"})],
|
||||
plain_text="hello",
|
||||
stream_id="stream_1",
|
||||
)
|
||||
|
||||
d = msg.to_rpc_dict()
|
||||
assert d["plain_text"] == "hello"
|
||||
assert len(d["message_segments"]) == 1
|
||||
|
||||
msg2 = MaiMessages.from_rpc_dict(d)
|
||||
assert msg2.plain_text == "hello"
|
||||
|
||||
def test_deepcopy(self):
|
||||
from maibot_sdk.messages import MaiMessages
|
||||
|
||||
msg = MaiMessages(plain_text="original")
|
||||
msg2 = msg.deepcopy()
|
||||
msg2.plain_text = "modified"
|
||||
assert msg.plain_text == "original"
|
||||
|
||||
def test_modify_flags(self):
|
||||
from maibot_sdk.messages import MaiMessages
|
||||
from maibot_sdk.types import ModifyFlag
|
||||
|
||||
msg = MaiMessages(plain_text="hello")
|
||||
assert msg.can_modify(ModifyFlag.CAN_MODIFY_PROMPT)
|
||||
|
||||
msg.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False)
|
||||
assert not msg.modify_prompt("new prompt")
|
||||
assert msg.llm_prompt is None
|
||||
|
||||
assert msg.modify_response("new response")
|
||||
assert msg.llm_response_content == "new response"
|
||||
|
||||
|
||||
# ─── WorkflowExecutor 测试 ────────────────────────────────
|
||||
|
||||
class TestWorkflowExecutor:
|
||||
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_pipeline_completes(self):
|
||||
"""无任何 workflow_step 注册时,pipeline 全阶段跳过,状态 completed"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, final_msg, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "test"},
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert result.return_message == "workflow completed"
|
||||
assert len(ctx.timings) == 6 # 6 stages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocking_hook_modifies_message(self):
|
||||
"""blocking hook 可以修改消息"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("upper", "workflow_step", "p1", {
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
msg = args.get("message", {})
|
||||
return {
|
||||
"hook_result": "continue",
|
||||
"modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
|
||||
}
|
||||
|
||||
result, final_msg, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "hello"},
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert final_msg["plain_text"] == "HELLO"
|
||||
assert len(ctx.modification_log) == 1
|
||||
assert ctx.modification_log[0].stage == "pre_process"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_stops_pipeline(self):
|
||||
"""HookResult.ABORT 立即终止 pipeline"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("blocker", "workflow_step", "p1", {
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {"hook_result": "abort"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "test"},
|
||||
)
|
||||
assert result.status == "aborted"
|
||||
assert result.stopped_at == "pre_process"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_stage(self):
|
||||
"""HookResult.SKIP_STAGE 跳过当前阶段剩余 hook"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
# high-priority hook 返回 skip_stage
|
||||
reg.register_component("skipper", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 100,
|
||||
"blocking": True,
|
||||
})
|
||||
# low-priority hook 不应被执行
|
||||
reg.register_component("checker", "workflow_step", "p2", {
|
||||
"stage": "ingress",
|
||||
"priority": 1,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
if comp_name == "skipper":
|
||||
return {"hook_result": "skip_stage"}
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, _ = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "test"}
|
||||
)
|
||||
assert result.status == "completed"
|
||||
# 只有 skipper 被调用,checker 被跳过
|
||||
assert call_log == ["skipper"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_filter(self):
|
||||
"""filter 条件不匹配时跳过 hook"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("only_dm", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"filter": {"chat_type": "direct"},
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
# 不匹配 filter —— hook 不应被调用
|
||||
await executor.execute(
|
||||
mock_invoke, message={"plain_text": "hi", "chat_type": "group"}
|
||||
)
|
||||
assert call_log == []
|
||||
|
||||
# 匹配 filter —— hook 应被调用
|
||||
await executor.execute(
|
||||
mock_invoke, message={"plain_text": "hi", "chat_type": "direct"}
|
||||
)
|
||||
assert call_log == ["only_dm"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_policy_skip(self):
|
||||
"""error_policy=skip 时跳过失败的 hook 继续执行"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("failer", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 100,
|
||||
"blocking": True,
|
||||
"error_policy": "skip",
|
||||
})
|
||||
reg.register_component("ok_step", "workflow_step", "p2", {
|
||||
"stage": "ingress",
|
||||
"priority": 1,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
if comp_name == "failer":
|
||||
raise RuntimeError("boom")
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "test"}
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert "failer" in call_log
|
||||
assert "ok_step" in call_log
|
||||
assert any("boom" in e for e in ctx.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_policy_abort(self):
|
||||
"""error_policy=abort(默认)时 pipeline 失败"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("failer", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
# error_policy defaults to "abort"
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
raise RuntimeError("fatal")
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "test"}
|
||||
)
|
||||
assert result.status == "failed"
|
||||
assert result.stopped_at == "ingress"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonblocking_hooks_concurrent(self):
|
||||
"""non-blocking hook 并发执行,不修改消息"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
for i in range(3):
|
||||
reg.register_component(f"nb_{i}", "workflow_step", f"p{i}", {
|
||||
"stage": "post_process",
|
||||
"priority": 0,
|
||||
"blocking": False,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
|
||||
|
||||
result, final_msg, _ = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "original"}
|
||||
)
|
||||
# non-blocking 的 modified_message 被忽略
|
||||
assert final_msg["plain_text"] == "original"
|
||||
# 给异步 task 时间完成
|
||||
await asyncio.sleep(0.1)
|
||||
assert result.status == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_routing(self):
|
||||
"""PLAN 阶段内置命令路由"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("help", "command", "p1", {
|
||||
"command_pattern": r"^/help",
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
if comp_name == "help":
|
||||
return {"output": "帮助信息"}
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "/help topic"}
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert ctx.matched_command == "p1.help"
|
||||
cmd_result = ctx.get_stage_output("plan", "command_result")
|
||||
assert cmd_result is not None
|
||||
assert cmd_result["output"] == "帮助信息"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_outputs(self):
|
||||
"""stage_outputs 数据在阶段间传递"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
# ingress 阶段写入数据
|
||||
reg.register_component("writer", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
})
|
||||
# pre_process 阶段读取数据
|
||||
reg.register_component("reader", "workflow_step", "p2", {
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
if comp_name == "writer":
|
||||
return {
|
||||
"hook_result": "continue",
|
||||
"stage_output": {"parsed_intent": "greeting"},
|
||||
}
|
||||
if comp_name == "reader":
|
||||
# 验证 stage_outputs 被传递过来
|
||||
outputs = args.get("stage_outputs", {})
|
||||
ingress_data = outputs.get("ingress", {})
|
||||
assert ingress_data.get("parsed_intent") == "greeting"
|
||||
return {"hook_result": "continue"}
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "hi"}
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
|
||||
|
||||
Reference in New Issue
Block a user