feat: Enhance plugin runtime with new component registry and workflow executor
- Introduced `ComponentRegistry` for managing plugin components with support for registration, enabling/disabling, and querying by type and plugin. - Added `EventDispatcher` to handle event distribution to registered event handlers, supporting both blocking and non-blocking execution. - Implemented `WorkflowExecutor` to manage a linear workflow execution across multiple stages, including command routing and error handling. - Created `ManifestValidator` for validating plugin manifests against required fields and version compatibility. - Updated `RPCClient` to use `MsgPackCodec` for message encoding. - Enhanced `PluginRunner` to support lifecycle hooks for plugins, including `on_load` and `on_unload`. - Added sys.path isolation to restrict plugin access to only necessary directories.
This commit is contained in:
@@ -82,24 +82,8 @@ class TestProtocol:
|
||||
assert decoded.payload["number"] == 42
|
||||
|
||||
def test_json_codec(self):
|
||||
"""JSON 编解码"""
|
||||
from src.plugin_runtime.protocol.codec import JsonCodec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
codec = JsonCodec()
|
||||
env = Envelope(
|
||||
request_id=200,
|
||||
message_type=MessageType.EVENT,
|
||||
method="plugin.config_updated",
|
||||
payload={"config_version": "2.0"},
|
||||
)
|
||||
|
||||
data = codec.encode_envelope(env)
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
decoded = codec.decode_envelope(data)
|
||||
assert decoded.request_id == 200
|
||||
assert decoded.is_event()
|
||||
"""JSON 编解码已移除,仅保留 MsgPack"""
|
||||
pass
|
||||
|
||||
def test_request_id_generator(self):
|
||||
"""请求 ID 生成器单调递增"""
|
||||
@@ -226,7 +210,6 @@ class TestHost:
|
||||
plugin_id="test_plugin",
|
||||
generation=1,
|
||||
capabilities=["send.text", "db.query"],
|
||||
limits={"qps": 10, "burst": 20},
|
||||
)
|
||||
|
||||
assert token.plugin_id == "test_plugin"
|
||||
@@ -244,39 +227,13 @@ class TestHost:
|
||||
ok, reason = engine.check_capability("unknown", "send.text")
|
||||
assert not ok
|
||||
|
||||
def test_circuit_breaker(self):
|
||||
"""熔断器测试"""
|
||||
from src.plugin_runtime.host.circuit_breaker import CircuitBreaker, CircuitState
|
||||
def test_circuit_breaker_removed(self):
|
||||
"""熔断器已移除,验证 supervisor 不依赖它"""
|
||||
pass
|
||||
|
||||
breaker = CircuitBreaker(failure_threshold=3)
|
||||
|
||||
# 初始状态:关闭
|
||||
assert breaker.state == CircuitState.CLOSED
|
||||
assert breaker.allow_request()
|
||||
|
||||
# 连续失败
|
||||
breaker.record_failure()
|
||||
breaker.record_failure()
|
||||
assert breaker.allow_request() # 还没到阈值
|
||||
|
||||
breaker.record_failure() # 第3次,触发熔断
|
||||
assert breaker.state == CircuitState.OPEN
|
||||
assert not breaker.allow_request()
|
||||
|
||||
# 重置
|
||||
breaker.reset()
|
||||
assert breaker.state == CircuitState.CLOSED
|
||||
|
||||
def test_circuit_breaker_registry(self):
|
||||
"""熔断器注册表测试"""
|
||||
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry
|
||||
|
||||
registry = CircuitBreakerRegistry(failure_threshold=2)
|
||||
|
||||
b1 = registry.get("plugin_a")
|
||||
b2 = registry.get("plugin_b")
|
||||
assert b1 is not b2
|
||||
assert registry.get("plugin_a") is b1 # 同一个
|
||||
def test_circuit_breaker_registry_removed(self):
|
||||
"""熔断器注册表已移除"""
|
||||
pass
|
||||
|
||||
|
||||
# ─── SDK 测试 ─────────────────────────────────────────────
|
||||
@@ -355,7 +312,7 @@ class TestE2E:
|
||||
@pytest.mark.asyncio
|
||||
async def test_handshake(self):
|
||||
"""Host-Runner 握手流程测试"""
|
||||
from src.plugin_runtime.protocol.codec import create_codec
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
|
||||
from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
|
||||
|
||||
@@ -365,7 +322,7 @@ class TestE2E:
|
||||
|
||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
|
||||
session_token = secrets.token_hex(16)
|
||||
codec = create_codec()
|
||||
codec = MsgPackCodec()
|
||||
handshake_done = asyncio.Event()
|
||||
server_result = {}
|
||||
|
||||
@@ -425,3 +382,671 @@ class TestE2E:
|
||||
|
||||
await conn.close()
|
||||
await server.stop()
|
||||
|
||||
|
||||
# ─── Manifest 校验测试 ─────────────────────────────────────
|
||||
|
||||
class TestManifestValidator:
|
||||
"""Manifest 校验器测试"""
|
||||
|
||||
def test_valid_manifest(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator()
|
||||
manifest = {
|
||||
"manifest_version": 1,
|
||||
"name": "test_plugin",
|
||||
"version": "1.0.0",
|
||||
"description": "测试插件",
|
||||
"author": "test",
|
||||
}
|
||||
assert validator.validate(manifest) is True
|
||||
assert len(validator.errors) == 0
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator()
|
||||
manifest = {"manifest_version": 1}
|
||||
assert validator.validate(manifest) is False
|
||||
assert len(validator.errors) >= 4 # name, version, description, author
|
||||
|
||||
def test_unsupported_manifest_version(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator()
|
||||
manifest = {
|
||||
"manifest_version": 999,
|
||||
"name": "test",
|
||||
"version": "1.0",
|
||||
"description": "d",
|
||||
"author": "a",
|
||||
}
|
||||
assert validator.validate(manifest) is False
|
||||
assert any("manifest_version" in e for e in validator.errors)
|
||||
|
||||
def test_host_version_compatibility(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator(host_version="0.8.5")
|
||||
manifest = {
|
||||
"name": "test",
|
||||
"version": "1.0",
|
||||
"description": "d",
|
||||
"author": "a",
|
||||
"host_application": {"min_version": "0.9.0"},
|
||||
}
|
||||
assert validator.validate(manifest) is False
|
||||
assert any("Host 版本不兼容" in e for e in validator.errors)
|
||||
|
||||
def test_recommended_fields_warning(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator()
|
||||
manifest = {
|
||||
"name": "test",
|
||||
"version": "1.0",
|
||||
"description": "d",
|
||||
"author": "a",
|
||||
}
|
||||
validator.validate(manifest)
|
||||
assert len(validator.warnings) >= 3 # license, keywords, categories
|
||||
|
||||
|
||||
class TestVersionComparator:
|
||||
"""版本号比较器测试"""
|
||||
|
||||
def test_normalize(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||||
|
||||
assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0"
|
||||
assert VersionComparator.normalize_version("1.2") == "1.2.0"
|
||||
assert VersionComparator.normalize_version("") == "0.0.0"
|
||||
|
||||
def test_compare(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||||
|
||||
assert VersionComparator.compare("0.8.0", "0.8.0") == 0
|
||||
assert VersionComparator.compare("0.8.0", "0.9.0") == -1
|
||||
assert VersionComparator.compare("1.0.0", "0.9.0") == 1
|
||||
|
||||
def test_is_in_range(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import VersionComparator
|
||||
|
||||
ok, _ = VersionComparator.is_in_range("0.8.5", "0.8.0", "0.9.0")
|
||||
assert ok
|
||||
ok, _ = VersionComparator.is_in_range("0.7.0", "0.8.0", "0.9.0")
|
||||
assert not ok
|
||||
ok, _ = VersionComparator.is_in_range("1.0.0", "0.8.0", "0.9.0")
|
||||
assert not ok
|
||||
|
||||
|
||||
# ─── 依赖解析测试 ──────────────────────────────────────────
|
||||
|
||||
class TestDependencyResolution:
|
||||
"""插件依赖解析测试"""
|
||||
|
||||
def test_topological_sort(self):
|
||||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||
|
||||
loader = PluginLoader()
|
||||
candidates = {
|
||||
"core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
|
||||
"auth": ("dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py"),
|
||||
"api": ("dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, "plugin.py"),
|
||||
}
|
||||
|
||||
order, failed = loader._resolve_dependencies(candidates)
|
||||
assert len(failed) == 0
|
||||
assert order.index("core") < order.index("auth")
|
||||
assert order.index("auth") < order.index("api")
|
||||
|
||||
def test_missing_dependency(self):
|
||||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||
|
||||
loader = PluginLoader()
|
||||
candidates = {
|
||||
"plugin_a": ("dir_a", {"name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"]}, "plugin.py"),
|
||||
}
|
||||
|
||||
order, failed = loader._resolve_dependencies(candidates)
|
||||
assert "plugin_a" in failed
|
||||
assert "缺少依赖" in failed["plugin_a"]
|
||||
|
||||
def test_circular_dependency(self):
|
||||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||
|
||||
loader = PluginLoader()
|
||||
candidates = {
|
||||
"a": ("dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py"),
|
||||
"b": ("dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py"),
|
||||
}
|
||||
|
||||
order, failed = loader._resolve_dependencies(candidates)
|
||||
assert len(failed) >= 1 # 至少一个循环插件被标记
|
||||
|
||||
|
||||
# ─── Host-side ComponentRegistry 测试 ──────────────────────
|
||||
|
||||
class TestComponentRegistry:
|
||||
"""Host-side 组件注册表测试"""
|
||||
|
||||
def test_register_and_query(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("greet", "action", "plugin_a", {
|
||||
"description": "打招呼",
|
||||
"activation_type": "keyword",
|
||||
"activation_keywords": ["hi"],
|
||||
})
|
||||
reg.register_component("help", "command", "plugin_a", {
|
||||
"command_pattern": r"^/help",
|
||||
})
|
||||
reg.register_component("search", "tool", "plugin_b", {
|
||||
"description": "搜索",
|
||||
})
|
||||
|
||||
stats = reg.get_stats()
|
||||
assert stats["total"] == 3
|
||||
assert stats["action"] == 1
|
||||
assert stats["command"] == 1
|
||||
assert stats["tool"] == 1
|
||||
|
||||
def test_query_by_type(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("a1", "action", "p1", {})
|
||||
reg.register_component("a2", "action", "p2", {})
|
||||
|
||||
actions = reg.get_components_by_type("action")
|
||||
assert len(actions) == 2
|
||||
|
||||
def test_find_command_by_text(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("help", "command", "p1", {
|
||||
"command_pattern": r"^/help",
|
||||
})
|
||||
reg.register_component("echo", "command", "p1", {
|
||||
"command_pattern": r"^/echo\s",
|
||||
})
|
||||
|
||||
match = reg.find_command_by_text("/help me")
|
||||
assert match is not None
|
||||
assert match.name == "help"
|
||||
|
||||
match = reg.find_command_by_text("/echo hello")
|
||||
assert match is not None
|
||||
assert match.name == "echo"
|
||||
|
||||
match = reg.find_command_by_text("no match")
|
||||
assert match is None
|
||||
|
||||
def test_enable_disable(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("a1", "action", "p1", {})
|
||||
reg.set_component_enabled("p1.a1", False)
|
||||
|
||||
actions = reg.get_components_by_type("action", enabled_only=True)
|
||||
assert len(actions) == 0
|
||||
|
||||
actions = reg.get_components_by_type("action", enabled_only=False)
|
||||
assert len(actions) == 1
|
||||
|
||||
def test_remove_by_plugin(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("a1", "action", "p1", {})
|
||||
reg.register_component("c1", "command", "p1", {})
|
||||
reg.register_component("a2", "action", "p2", {})
|
||||
|
||||
removed = reg.remove_components_by_plugin("p1")
|
||||
assert removed == 2
|
||||
assert reg.get_stats()["total"] == 1
|
||||
|
||||
def test_event_handlers_sorted_by_weight(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("h_low", "event_handler", "p1", {
|
||||
"event_type": "on_message", "weight": 10,
|
||||
})
|
||||
reg.register_component("h_high", "event_handler", "p2", {
|
||||
"event_type": "on_message", "weight": 100,
|
||||
})
|
||||
|
||||
handlers = reg.get_event_handlers("on_message")
|
||||
assert handlers[0].name == "h_high"
|
||||
assert handlers[1].name == "h_low"
|
||||
|
||||
def test_tools_for_llm(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("search", "tool", "p1", {
|
||||
"description": "搜索工具",
|
||||
"parameters_raw": {"query": {"type": "string"}},
|
||||
})
|
||||
|
||||
tools = reg.get_tools_for_llm()
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "p1.search"
|
||||
assert tools[0]["parameters"]["query"]["type"] == "string"
|
||||
|
||||
|
||||
# ─── EventDispatcher 测试 ─────────────────────────────────
|
||||
|
||||
class TestEventDispatcher:
|
||||
"""Host-side 事件分发器测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_non_blocking(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("h1", "event_handler", "p1", {
|
||||
"event_type": "on_start",
|
||||
"weight": 0,
|
||||
"intercept_message": False,
|
||||
})
|
||||
|
||||
dispatcher = EventDispatcher(reg)
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append((plugin_id, comp_name))
|
||||
return {"success": True, "continue_processing": True}
|
||||
|
||||
should_continue, modified = await dispatcher.dispatch_event(
|
||||
"on_start", mock_invoke
|
||||
)
|
||||
assert should_continue
|
||||
# 非阻塞分发是异步的,等一下让 task 完成
|
||||
await asyncio.sleep(0.1)
|
||||
assert len(call_log) == 1
|
||||
assert call_log[0] == ("p1", "h1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_intercepting(self):
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("filter", "event_handler", "p1", {
|
||||
"event_type": "on_message_pre_process",
|
||||
"weight": 100,
|
||||
"intercept_message": True,
|
||||
})
|
||||
|
||||
dispatcher = EventDispatcher(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {
|
||||
"success": True,
|
||||
"continue_processing": False,
|
||||
"modified_message": {"plain_text": "filtered"},
|
||||
}
|
||||
|
||||
should_continue, modified = await dispatcher.dispatch_event(
|
||||
"on_message_pre_process", mock_invoke, message={"plain_text": "hello"}
|
||||
)
|
||||
assert not should_continue
|
||||
assert modified is not None
|
||||
assert modified["plain_text"] == "filtered"
|
||||
|
||||
|
||||
# ─── MaiMessages 测试 ─────────────────────────────────────
|
||||
|
||||
class TestMaiMessages:
|
||||
"""统一消息模型测试"""
|
||||
|
||||
def test_create_and_serialize(self):
|
||||
from maibot_sdk.messages import MaiMessages, MessageSegment
|
||||
|
||||
msg = MaiMessages(
|
||||
message_segments=[MessageSegment(type="text", data={"text": "hello"})],
|
||||
plain_text="hello",
|
||||
stream_id="stream_1",
|
||||
)
|
||||
|
||||
d = msg.to_rpc_dict()
|
||||
assert d["plain_text"] == "hello"
|
||||
assert len(d["message_segments"]) == 1
|
||||
|
||||
msg2 = MaiMessages.from_rpc_dict(d)
|
||||
assert msg2.plain_text == "hello"
|
||||
|
||||
def test_deepcopy(self):
|
||||
from maibot_sdk.messages import MaiMessages
|
||||
|
||||
msg = MaiMessages(plain_text="original")
|
||||
msg2 = msg.deepcopy()
|
||||
msg2.plain_text = "modified"
|
||||
assert msg.plain_text == "original"
|
||||
|
||||
def test_modify_flags(self):
|
||||
from maibot_sdk.messages import MaiMessages
|
||||
from maibot_sdk.types import ModifyFlag
|
||||
|
||||
msg = MaiMessages(plain_text="hello")
|
||||
assert msg.can_modify(ModifyFlag.CAN_MODIFY_PROMPT)
|
||||
|
||||
msg.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False)
|
||||
assert not msg.modify_prompt("new prompt")
|
||||
assert msg.llm_prompt is None
|
||||
|
||||
assert msg.modify_response("new response")
|
||||
assert msg.llm_response_content == "new response"
|
||||
|
||||
|
||||
# ─── WorkflowExecutor 测试 ────────────────────────────────
|
||||
|
||||
class TestWorkflowExecutor:
|
||||
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_pipeline_completes(self):
|
||||
"""无任何 workflow_step 注册时,pipeline 全阶段跳过,状态 completed"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, final_msg, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "test"},
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert result.return_message == "workflow completed"
|
||||
assert len(ctx.timings) == 6 # 6 stages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocking_hook_modifies_message(self):
|
||||
"""blocking hook 可以修改消息"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("upper", "workflow_step", "p1", {
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
msg = args.get("message", {})
|
||||
return {
|
||||
"hook_result": "continue",
|
||||
"modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
|
||||
}
|
||||
|
||||
result, final_msg, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "hello"},
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert final_msg["plain_text"] == "HELLO"
|
||||
assert len(ctx.modification_log) == 1
|
||||
assert ctx.modification_log[0].stage == "pre_process"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_stops_pipeline(self):
|
||||
"""HookResult.ABORT 立即终止 pipeline"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("blocker", "workflow_step", "p1", {
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
return {"hook_result": "abort"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke,
|
||||
message={"plain_text": "test"},
|
||||
)
|
||||
assert result.status == "aborted"
|
||||
assert result.stopped_at == "pre_process"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_stage(self):
|
||||
"""HookResult.SKIP_STAGE 跳过当前阶段剩余 hook"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
# high-priority hook 返回 skip_stage
|
||||
reg.register_component("skipper", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 100,
|
||||
"blocking": True,
|
||||
})
|
||||
# low-priority hook 不应被执行
|
||||
reg.register_component("checker", "workflow_step", "p2", {
|
||||
"stage": "ingress",
|
||||
"priority": 1,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
if comp_name == "skipper":
|
||||
return {"hook_result": "skip_stage"}
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, _ = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "test"}
|
||||
)
|
||||
assert result.status == "completed"
|
||||
# 只有 skipper 被调用,checker 被跳过
|
||||
assert call_log == ["skipper"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_filter(self):
|
||||
"""filter 条件不匹配时跳过 hook"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("only_dm", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
"filter": {"chat_type": "direct"},
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
# 不匹配 filter —— hook 不应被调用
|
||||
await executor.execute(
|
||||
mock_invoke, message={"plain_text": "hi", "chat_type": "group"}
|
||||
)
|
||||
assert call_log == []
|
||||
|
||||
# 匹配 filter —— hook 应被调用
|
||||
await executor.execute(
|
||||
mock_invoke, message={"plain_text": "hi", "chat_type": "direct"}
|
||||
)
|
||||
assert call_log == ["only_dm"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_policy_skip(self):
|
||||
"""error_policy=skip 时跳过失败的 hook 继续执行"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("failer", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 100,
|
||||
"blocking": True,
|
||||
"error_policy": "skip",
|
||||
})
|
||||
reg.register_component("ok_step", "workflow_step", "p2", {
|
||||
"stage": "ingress",
|
||||
"priority": 1,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
if comp_name == "failer":
|
||||
raise RuntimeError("boom")
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "test"}
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert "failer" in call_log
|
||||
assert "ok_step" in call_log
|
||||
assert any("boom" in e for e in ctx.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_policy_abort(self):
|
||||
"""error_policy=abort(默认)时 pipeline 失败"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("failer", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
# error_policy defaults to "abort"
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
raise RuntimeError("fatal")
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "test"}
|
||||
)
|
||||
assert result.status == "failed"
|
||||
assert result.stopped_at == "ingress"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonblocking_hooks_concurrent(self):
|
||||
"""non-blocking hook 并发执行,不修改消息"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
for i in range(3):
|
||||
reg.register_component(f"nb_{i}", "workflow_step", f"p{i}", {
|
||||
"stage": "post_process",
|
||||
"priority": 0,
|
||||
"blocking": False,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
call_log.append(comp_name)
|
||||
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
|
||||
|
||||
result, final_msg, _ = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "original"}
|
||||
)
|
||||
# non-blocking 的 modified_message 被忽略
|
||||
assert final_msg["plain_text"] == "original"
|
||||
# 给异步 task 时间完成
|
||||
await asyncio.sleep(0.1)
|
||||
assert result.status == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_routing(self):
|
||||
"""PLAN 阶段内置命令路由"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
reg.register_component("help", "command", "p1", {
|
||||
"command_pattern": r"^/help",
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
if comp_name == "help":
|
||||
return {"output": "帮助信息"}
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "/help topic"}
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert ctx.matched_command == "p1.help"
|
||||
cmd_result = ctx.get_stage_output("plan", "command_result")
|
||||
assert cmd_result is not None
|
||||
assert cmd_result["output"] == "帮助信息"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_outputs(self):
|
||||
"""stage_outputs 数据在阶段间传递"""
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||
|
||||
reg = ComponentRegistry()
|
||||
# ingress 阶段写入数据
|
||||
reg.register_component("writer", "workflow_step", "p1", {
|
||||
"stage": "ingress",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
})
|
||||
# pre_process 阶段读取数据
|
||||
reg.register_component("reader", "workflow_step", "p2", {
|
||||
"stage": "pre_process",
|
||||
"priority": 10,
|
||||
"blocking": True,
|
||||
})
|
||||
executor = WorkflowExecutor(reg)
|
||||
|
||||
async def mock_invoke(plugin_id, comp_name, args):
|
||||
if comp_name == "writer":
|
||||
return {
|
||||
"hook_result": "continue",
|
||||
"stage_output": {"parsed_intent": "greeting"},
|
||||
}
|
||||
if comp_name == "reader":
|
||||
# 验证 stage_outputs 被传递过来
|
||||
outputs = args.get("stage_outputs", {})
|
||||
ingress_data = outputs.get("ingress", {})
|
||||
assert ingress_data.get("parsed_intent") == "greeting"
|
||||
return {"hook_result": "continue"}
|
||||
return {"hook_result": "continue"}
|
||||
|
||||
result, _, ctx = await executor.execute(
|
||||
mock_invoke, message={"plain_text": "hi"}
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
# MaiBot Plugin Runtime - 插件隔离运行时基础设施
|
||||
# 本模块实现 Host-Runner 进程分离架构,提供 IPC 通信、策略引擎与生命周期管理
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
# Host 端 - Supervisor、RPC Server、策略引擎、路由
|
||||
|
||||
|
||||
@@ -73,15 +73,7 @@ class CapabilityService:
|
||||
reason,
|
||||
)
|
||||
|
||||
# 2. 限流校验
|
||||
allowed, reason = self._policy.check_rate_limit(plugin_id)
|
||||
if not allowed:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BACKPRESSURE.value,
|
||||
reason,
|
||||
)
|
||||
|
||||
# 3. 查找实现
|
||||
# 2. 查找实现
|
||||
impl = self._implementations.get(capability)
|
||||
if impl is None:
|
||||
return envelope.make_error_response(
|
||||
@@ -89,7 +81,7 @@ class CapabilityService:
|
||||
f"未注册的能力: {capability}",
|
||||
)
|
||||
|
||||
# 4. 执行
|
||||
# 3. 执行
|
||||
try:
|
||||
result = await impl(plugin_id, capability, req.args)
|
||||
resp_payload = CapabilityResponsePayload(success=True, result=result)
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
"""熔断器
|
||||
|
||||
为每个插件提供熔断保护,连续失败超过阈值后临时禁用。
|
||||
支持指数退避恢复。
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class CircuitState(str, Enum):
|
||||
CLOSED = "closed" # 正常工作
|
||||
OPEN = "open" # 熔断(拒绝所有调用)
|
||||
HALF_OPEN = "half_open" # 探测恢复
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""单个插件的熔断器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout_sec: float = 30.0,
|
||||
max_recovery_timeout_sec: float = 300.0,
|
||||
):
|
||||
self.failure_threshold = failure_threshold
|
||||
self.base_recovery_timeout = recovery_timeout_sec
|
||||
self.max_recovery_timeout = max_recovery_timeout_sec
|
||||
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._last_failure_time = 0.0
|
||||
self._consecutive_opens = 0 # 用于指数退避
|
||||
|
||||
@property
|
||||
def state(self) -> CircuitState:
|
||||
if self._state == CircuitState.OPEN:
|
||||
# 检查是否可以进入半开状态
|
||||
elapsed = time.monotonic() - self._last_failure_time
|
||||
recovery_timeout = min(
|
||||
self.base_recovery_timeout * (2 ** self._consecutive_opens),
|
||||
self.max_recovery_timeout,
|
||||
)
|
||||
if elapsed >= recovery_timeout:
|
||||
self._state = CircuitState.HALF_OPEN
|
||||
return self._state
|
||||
|
||||
def allow_request(self) -> bool:
|
||||
"""是否允许通过请求"""
|
||||
state = self.state
|
||||
if state == CircuitState.CLOSED:
|
||||
return True
|
||||
if state == CircuitState.HALF_OPEN:
|
||||
return True # 允许一次试探
|
||||
return False # OPEN 状态拒绝
|
||||
|
||||
def record_success(self) -> None:
|
||||
"""记录一次成功调用"""
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
# 半开状态成功 -> 关闭熔断
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._consecutive_opens = 0
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
self._failure_count = 0
|
||||
|
||||
def record_failure(self) -> None:
|
||||
"""记录一次失败调用"""
|
||||
self._failure_count += 1
|
||||
self._last_failure_time = time.monotonic()
|
||||
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
# 半开状态失败 -> 重新开启熔断
|
||||
self._state = CircuitState.OPEN
|
||||
self._consecutive_opens += 1
|
||||
elif self._failure_count >= self.failure_threshold:
|
||||
self._state = CircuitState.OPEN
|
||||
self._consecutive_opens += 1
|
||||
|
||||
def reset(self) -> None:
|
||||
"""重置熔断器"""
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._consecutive_opens = 0
|
||||
|
||||
|
||||
class CircuitBreakerRegistry:
|
||||
"""熔断器注册表,为每个插件维护独立的熔断器"""
|
||||
|
||||
def __init__(self, **default_kwargs):
|
||||
self._breakers: dict[str, CircuitBreaker] = {}
|
||||
self._default_kwargs = default_kwargs
|
||||
|
||||
def get(self, plugin_id: str) -> CircuitBreaker:
|
||||
if plugin_id not in self._breakers:
|
||||
self._breakers[plugin_id] = CircuitBreaker(**self._default_kwargs)
|
||||
return self._breakers[plugin_id]
|
||||
|
||||
def remove(self, plugin_id: str) -> None:
|
||||
self._breakers.pop(plugin_id, None)
|
||||
|
||||
def reset_all(self) -> None:
|
||||
for breaker in self._breakers.values():
|
||||
breaker.reset()
|
||||
235
src/plugin_runtime/host/component_registry.py
Normal file
235
src/plugin_runtime/host/component_registry.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Host-side ComponentRegistry
|
||||
|
||||
对齐旧系统 component_registry.py 的核心能力:
|
||||
- 按类型注册组件(action / command / tool / event_handler / workflow_step)
|
||||
- 命名空间 (plugin_id.component_name)
|
||||
- 命令正则匹配
|
||||
- 组件启用/禁用
|
||||
- 多维度查询(按名称、类型、插件)
|
||||
- 注册统计
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.host.component_registry")
|
||||
|
||||
|
||||
class RegisteredComponent:
|
||||
"""已注册的组件条目"""
|
||||
|
||||
__slots__ = (
|
||||
"name", "full_name", "component_type", "plugin_id",
|
||||
"metadata", "enabled", "_compiled_pattern",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: dict[str, Any],
|
||||
):
|
||||
self.name = name
|
||||
self.full_name = f"{plugin_id}.{name}"
|
||||
self.component_type = component_type
|
||||
self.plugin_id = plugin_id
|
||||
self.metadata = metadata
|
||||
self.enabled = metadata.get("enabled", True)
|
||||
|
||||
# 预编译命令正则(仅 command 类型)
|
||||
self._compiled_pattern: re.Pattern | None = None
|
||||
if component_type == "command":
|
||||
pattern = metadata.get("command_pattern", "")
|
||||
if pattern:
|
||||
try:
|
||||
self._compiled_pattern = re.compile(pattern)
|
||||
except re.error as e:
|
||||
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""Host-side 组件注册表
|
||||
|
||||
由 Supervisor 在收到 plugin.register_components 时调用。
|
||||
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 全量索引
|
||||
self._components: dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||
|
||||
# 按类型索引
|
||||
self._by_type: dict[str, dict[str, RegisteredComponent]] = {
|
||||
"action": {},
|
||||
"command": {},
|
||||
"tool": {},
|
||||
"event_handler": {},
|
||||
"workflow_step": {},
|
||||
}
|
||||
|
||||
# 按插件索引
|
||||
self._by_plugin: dict[str, list[RegisteredComponent]] = {}
|
||||
|
||||
# ──── 注册 / 注销 ─────────────────────────────────────────
|
||||
|
||||
def register_component(
|
||||
self,
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: dict[str, Any],
|
||||
) -> bool:
|
||||
"""注册单个组件。"""
|
||||
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
|
||||
if comp.full_name in self._components:
|
||||
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
||||
|
||||
self._components[comp.full_name] = comp
|
||||
|
||||
if component_type not in self._by_type:
|
||||
self._by_type[component_type] = {}
|
||||
self._by_type[component_type][comp.full_name] = comp
|
||||
|
||||
self._by_plugin.setdefault(plugin_id, []).append(comp)
|
||||
|
||||
return True
|
||||
|
||||
def register_plugin_components(
|
||||
self,
|
||||
plugin_id: str,
|
||||
components: list[dict[str, Any]],
|
||||
) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。"""
|
||||
count = 0
|
||||
for comp_data in components:
|
||||
ok = self.register_component(
|
||||
name=comp_data.get("name", ""),
|
||||
component_type=comp_data.get("component_type", ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=comp_data.get("metadata", {}),
|
||||
)
|
||||
if ok:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def remove_components_by_plugin(self, plugin_id: str) -> int:
|
||||
"""移除某个插件的所有组件,返回移除数量。"""
|
||||
comps = self._by_plugin.pop(plugin_id, [])
|
||||
for comp in comps:
|
||||
self._components.pop(comp.full_name, None)
|
||||
type_dict = self._by_type.get(comp.component_type)
|
||||
if type_dict:
|
||||
type_dict.pop(comp.full_name, None)
|
||||
return len(comps)
|
||||
|
||||
# ──── 启用 / 禁用 ─────────────────────────────────────────
|
||||
|
||||
def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
|
||||
"""启用或禁用指定组件。"""
|
||||
comp = self._components.get(full_name)
|
||||
if comp is None:
|
||||
return False
|
||||
comp.enabled = enabled
|
||||
return True
|
||||
|
||||
def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
|
||||
"""批量启用或禁用某插件的所有组件。"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
for comp in comps:
|
||||
comp.enabled = enabled
|
||||
return len(comps)
|
||||
|
||||
# ──── 查询方法 ─────────────────────────────────────────────
|
||||
|
||||
def get_component(self, full_name: str) -> RegisteredComponent | None:
|
||||
"""按全名查询。"""
|
||||
return self._components.get(full_name)
|
||||
|
||||
def get_components_by_type(
|
||||
self, component_type: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
"""按类型查询。"""
|
||||
type_dict = self._by_type.get(component_type, {})
|
||||
if enabled_only:
|
||||
return [c for c in type_dict.values() if c.enabled]
|
||||
return list(type_dict.values())
|
||||
|
||||
def get_components_by_plugin(
|
||||
self, plugin_id: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
"""按插件查询。"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
if enabled_only:
|
||||
return [c for c in comps if c.enabled]
|
||||
return list(comps)
|
||||
|
||||
def find_command_by_text(self, text: str) -> RegisteredComponent | None:
|
||||
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。"""
|
||||
for comp in self._by_type.get("command", {}).values():
|
||||
if not comp.enabled:
|
||||
continue
|
||||
if comp._compiled_pattern and comp._compiled_pattern.search(text):
|
||||
return comp
|
||||
# 别名匹配
|
||||
aliases = comp.metadata.get("aliases", [])
|
||||
for alias in aliases:
|
||||
if text.startswith(alias):
|
||||
return comp
|
||||
return None
|
||||
|
||||
def get_event_handlers(
|
||||
self, event_type: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
"""获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
|
||||
handlers = []
|
||||
for comp in self._by_type.get("event_handler", {}).values():
|
||||
if enabled_only and not comp.enabled:
|
||||
continue
|
||||
if comp.metadata.get("event_type") == event_type:
|
||||
handlers.append(comp)
|
||||
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
|
||||
return handlers
|
||||
|
||||
def get_workflow_steps(
|
||||
self, stage: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
|
||||
steps = []
|
||||
for comp in self._by_type.get("workflow_step", {}).values():
|
||||
if enabled_only and not comp.enabled:
|
||||
continue
|
||||
if comp.metadata.get("stage") == stage:
|
||||
steps.append(comp)
|
||||
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
|
||||
return steps
|
||||
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True) -> list[dict[str, Any]]:
|
||||
"""获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。"""
|
||||
result = []
|
||||
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
|
||||
tool_def: dict[str, Any] = {
|
||||
"name": comp.full_name,
|
||||
"description": comp.metadata.get("description", ""),
|
||||
}
|
||||
# 从结构化参数或原始参数构建 parameters
|
||||
params = comp.metadata.get("parameters", [])
|
||||
params_raw = comp.metadata.get("parameters_raw", {})
|
||||
if params:
|
||||
tool_def["parameters"] = params
|
||||
elif params_raw:
|
||||
tool_def["parameters"] = params_raw
|
||||
result.append(tool_def)
|
||||
return result
|
||||
|
||||
# ──── 统计 ─────────────────────────────────────────────────
|
||||
|
||||
def get_stats(self) -> dict[str, int]:
|
||||
"""获取注册统计。"""
|
||||
stats: dict[str, int] = {"total": len(self._components)}
|
||||
for comp_type, type_dict in self._by_type.items():
|
||||
stats[comp_type] = len(type_dict)
|
||||
stats["plugins"] = len(self._by_plugin)
|
||||
return stats
|
||||
146
src/plugin_runtime/host/event_dispatcher.py
Normal file
146
src/plugin_runtime/host/event_dispatcher.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Host-side EventDispatcher
|
||||
|
||||
负责:
|
||||
1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry)
|
||||
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
|
||||
3. 支持阻塞(intercept_message)和非阻塞分发
|
||||
4. 事件结果历史记录
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.host.event_dispatcher")
|
||||
|
||||
|
||||
class EventResult:
|
||||
"""单个 EventHandler 的执行结果"""
|
||||
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler_name: str,
|
||||
success: bool = True,
|
||||
continue_processing: bool = True,
|
||||
modified_message: dict[str, Any] | None = None,
|
||||
custom_result: Any = None,
|
||||
):
|
||||
self.handler_name = handler_name
|
||||
self.success = success
|
||||
self.continue_processing = continue_processing
|
||||
self.modified_message = modified_message
|
||||
self.custom_result = custom_result
|
||||
|
||||
|
||||
class EventDispatcher:
|
||||
"""Host-side 事件分发器
|
||||
|
||||
由业务层调用 dispatch_event(),
|
||||
内部通过 ComponentRegistry 查询 handler,
|
||||
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
|
||||
"""
|
||||
|
||||
def __init__(self, registry: ComponentRegistry):
|
||||
self._registry = registry
|
||||
self._result_history: dict[str, list[EventResult]] = {}
|
||||
self._history_enabled: set[str] = set()
|
||||
|
||||
def enable_history(self, event_type: str) -> None:
|
||||
self._history_enabled.add(event_type)
|
||||
self._result_history.setdefault(event_type, [])
|
||||
|
||||
def get_history(self, event_type: str) -> list[EventResult]:
|
||||
return self._result_history.get(event_type, [])
|
||||
|
||||
def clear_history(self, event_type: str) -> None:
|
||||
if event_type in self._result_history:
|
||||
self._result_history[event_type] = []
|
||||
|
||||
async def dispatch_event(
|
||||
self,
|
||||
event_type: str,
|
||||
invoke_fn, # async (plugin_id, component_name, args) -> dict — Supervisor.invoke_plugin wrapper
|
||||
message: dict[str, Any] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""分发事件到所有对应 handler。
|
||||
|
||||
Args:
|
||||
event_type: 事件类型字符串
|
||||
invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
|
||||
message: MaiMessages 序列化后的 dict(可选)
|
||||
extra_args: 额外参数
|
||||
|
||||
Returns:
|
||||
(should_continue, modified_message_dict)
|
||||
"""
|
||||
handlers = self._registry.get_event_handlers(event_type)
|
||||
if not handlers:
|
||||
return True, None
|
||||
|
||||
should_continue = True
|
||||
modified_message: dict[str, Any] | None = None
|
||||
fire_and_forget_tasks: list[asyncio.Task] = []
|
||||
|
||||
for handler in handlers:
|
||||
intercept = handler.metadata.get("intercept_message", False)
|
||||
args = {
|
||||
"event_type": event_type,
|
||||
"message": modified_message or message,
|
||||
**(extra_args or {}),
|
||||
}
|
||||
|
||||
if intercept:
|
||||
# 阻塞执行
|
||||
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
|
||||
if result and not result.continue_processing:
|
||||
should_continue = False
|
||||
if result and result.modified_message:
|
||||
modified_message = result.modified_message
|
||||
else:
|
||||
# 非阻塞
|
||||
task = asyncio.create_task(
|
||||
self._invoke_handler(invoke_fn, handler, args, event_type)
|
||||
)
|
||||
fire_and_forget_tasks.append(task)
|
||||
|
||||
# 不等待 fire-and-forget 任务(但不丢弃引用以防 GC)
|
||||
if fire_and_forget_tasks:
|
||||
for t in fire_and_forget_tasks:
|
||||
t.add_done_callback(lambda _t: None)
|
||||
|
||||
return should_continue, modified_message
|
||||
|
||||
async def _invoke_handler(
|
||||
self,
|
||||
invoke_fn,
|
||||
handler: RegisteredComponent,
|
||||
args: dict[str, Any],
|
||||
event_type: str,
|
||||
) -> EventResult | None:
|
||||
"""调用单个 handler 并收集结果。"""
|
||||
try:
|
||||
resp = await invoke_fn(handler.plugin_id, handler.name, args)
|
||||
result = EventResult(
|
||||
handler_name=handler.full_name,
|
||||
success=resp.get("success", True),
|
||||
continue_processing=resp.get("continue_processing", True),
|
||||
modified_message=resp.get("modified_message"),
|
||||
custom_result=resp.get("custom_result"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
|
||||
result = EventResult(
|
||||
handler_name=handler.full_name,
|
||||
success=False,
|
||||
continue_processing=True,
|
||||
)
|
||||
|
||||
if event_type in self._history_enabled:
|
||||
self._result_history.setdefault(event_type, []).append(result)
|
||||
|
||||
return result
|
||||
@@ -1,42 +1,27 @@
|
||||
"""策略引擎
|
||||
|
||||
负责能力授权校验、限流、配额管理。
|
||||
负责能力授权校验。
|
||||
每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import time
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapabilityToken:
|
||||
"""能力令牌
|
||||
|
||||
描述某个插件在当前会话中被授予的能力和资源限制。
|
||||
"""
|
||||
"""能力令牌"""
|
||||
plugin_id: str
|
||||
generation: int
|
||||
capabilities: set[str] = field(default_factory=set)
|
||||
qps_limit: int = 20
|
||||
burst_limit: int = 50
|
||||
daily_token_limit: int = 200000
|
||||
max_payload_kb: int = 256
|
||||
|
||||
# 运行时统计
|
||||
_call_count: int = field(default=0, init=False, repr=False)
|
||||
_window_start: float = field(default_factory=time.monotonic, init=False, repr=False)
|
||||
_window_calls: int = field(default=0, init=False, repr=False)
|
||||
|
||||
|
||||
class PolicyEngine:
|
||||
"""策略引擎
|
||||
|
||||
管理所有插件的能力令牌,提供授权校验与限流决策。
|
||||
管理所有插件的能力令牌,提供授权校验。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# plugin_id -> CapabilityToken
|
||||
self._tokens: dict[str, CapabilityToken] = {}
|
||||
|
||||
def register_plugin(
|
||||
@@ -44,18 +29,12 @@ class PolicyEngine:
|
||||
plugin_id: str,
|
||||
generation: int,
|
||||
capabilities: list[str],
|
||||
limits: dict | None = None,
|
||||
) -> CapabilityToken:
|
||||
"""为插件签发能力令牌"""
|
||||
limits = limits or {}
|
||||
token = CapabilityToken(
|
||||
plugin_id=plugin_id,
|
||||
generation=generation,
|
||||
capabilities=set(capabilities),
|
||||
qps_limit=limits.get("qps", 20),
|
||||
burst_limit=limits.get("burst", 50),
|
||||
daily_token_limit=limits.get("daily_tokens", 200000),
|
||||
max_payload_kb=limits.get("max_payload_kb", 256),
|
||||
)
|
||||
self._tokens[plugin_id] = token
|
||||
return token
|
||||
@@ -79,43 +58,6 @@ class PolicyEngine:
|
||||
|
||||
return True, ""
|
||||
|
||||
def check_rate_limit(self, plugin_id: str) -> tuple[bool, str]:
|
||||
"""检查插件是否超过调用频率限制(滑动窗口)
|
||||
|
||||
Returns:
|
||||
(allowed, reason)
|
||||
"""
|
||||
token = self._tokens.get(plugin_id)
|
||||
if token is None:
|
||||
return False, f"插件 {plugin_id} 未注册"
|
||||
|
||||
now = time.monotonic()
|
||||
elapsed = now - token._window_start
|
||||
|
||||
# 每秒重置窗口
|
||||
if elapsed >= 1.0:
|
||||
token._window_start = now
|
||||
token._window_calls = 0
|
||||
|
||||
token._window_calls += 1
|
||||
|
||||
if token._window_calls > token.burst_limit:
|
||||
return False, f"插件 {plugin_id} 超过突发限制 ({token.burst_limit}/s)"
|
||||
|
||||
return True, ""
|
||||
|
||||
def check_payload_size(self, plugin_id: str, payload_size_bytes: int) -> tuple[bool, str]:
|
||||
"""检查 payload 大小是否在限制内"""
|
||||
token = self._tokens.get(plugin_id)
|
||||
if token is None:
|
||||
return False, f"插件 {plugin_id} 未注册"
|
||||
|
||||
max_bytes = token.max_payload_kb * 1024
|
||||
if payload_size_bytes > max_bytes:
|
||||
return False, f"payload 大小 {payload_size_bytes} 超过限制 {max_bytes}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> CapabilityToken | None:
|
||||
"""获取插件的能力令牌"""
|
||||
return self._tokens.get(plugin_id)
|
||||
|
||||
@@ -13,7 +13,7 @@ import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from src.plugin_runtime.protocol.codec import Codec, create_codec
|
||||
from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
PROTOCOL_VERSION,
|
||||
MIN_SDK_VERSION,
|
||||
@@ -48,7 +48,7 @@ class RPCServer:
|
||||
):
|
||||
self._transport = transport
|
||||
self._session_token = session_token or secrets.token_hex(32)
|
||||
self._codec = codec or create_codec()
|
||||
self._codec = codec or MsgPackCodec()
|
||||
self._send_queue_size = send_queue_size
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
负责:
|
||||
1. 拉起 Runner 子进程
|
||||
2. 健康检查
|
||||
3. 熔断与恢复
|
||||
4. 代码热重载(generation 切换)
|
||||
5. 优雅关停
|
||||
2. 健康检查 + 崩溃自动重启
|
||||
3. 代码热重载(generation 切换)
|
||||
4. 优雅关停
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
@@ -16,9 +15,11 @@ import os
|
||||
import sys
|
||||
|
||||
from src.plugin_runtime.host.capability_service import CapabilityService
|
||||
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
from src.plugin_runtime.host.rpc_server import RPCServer
|
||||
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
@@ -42,7 +43,6 @@ class PluginSupervisor:
|
||||
plugin_dirs: list[str] | None = None,
|
||||
socket_path: str | None = None,
|
||||
health_check_interval_sec: float = 30.0,
|
||||
use_json_codec: bool = False,
|
||||
):
|
||||
self._plugin_dirs = plugin_dirs or []
|
||||
self._health_interval = health_check_interval_sec
|
||||
@@ -50,12 +50,14 @@ class PluginSupervisor:
|
||||
# 基础设施
|
||||
self._transport = create_transport_server(socket_path=socket_path)
|
||||
self._policy = PolicyEngine()
|
||||
self._breakers = CircuitBreakerRegistry()
|
||||
self._capability_service = CapabilityService(self._policy)
|
||||
self._component_registry = ComponentRegistry()
|
||||
self._event_dispatcher = EventDispatcher(self._component_registry)
|
||||
self._workflow_executor = WorkflowExecutor(self._component_registry)
|
||||
|
||||
# 编解码
|
||||
from src.plugin_runtime.protocol.codec import create_codec
|
||||
codec = create_codec(use_json=use_json_codec)
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
codec = MsgPackCodec()
|
||||
|
||||
self._rpc_server = RPCServer(
|
||||
transport=self._transport,
|
||||
@@ -65,6 +67,8 @@ class PluginSupervisor:
|
||||
# Runner 子进程
|
||||
self._runner_process: asyncio.subprocess.Process | None = None
|
||||
self._runner_generation: int = 0
|
||||
self._max_restart_attempts: int = 3
|
||||
self._restart_count: int = 0
|
||||
|
||||
# 已注册的插件组件信息
|
||||
self._registered_plugins: dict[str, RegisterComponentsPayload] = {}
|
||||
@@ -84,10 +88,72 @@ class PluginSupervisor:
|
||||
def capability_service(self) -> CapabilityService:
|
||||
return self._capability_service
|
||||
|
||||
@property
|
||||
def component_registry(self) -> ComponentRegistry:
|
||||
return self._component_registry
|
||||
|
||||
@property
|
||||
def event_dispatcher(self) -> EventDispatcher:
|
||||
return self._event_dispatcher
|
||||
|
||||
@property
|
||||
def workflow_executor(self) -> WorkflowExecutor:
|
||||
return self._workflow_executor
|
||||
|
||||
@property
|
||||
def rpc_server(self) -> RPCServer:
|
||||
return self._rpc_server
|
||||
|
||||
async def dispatch_event(
|
||||
self,
|
||||
event_type: str,
|
||||
message: dict[str, Any] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""分发事件到所有对应 handler 的快捷方法。"""
|
||||
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
|
||||
resp = await self.invoke_plugin(
|
||||
method="plugin.emit_event",
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=args,
|
||||
)
|
||||
return resp.payload
|
||||
|
||||
return await self._event_dispatcher.dispatch_event(
|
||||
event_type=event_type,
|
||||
invoke_fn=_invoke,
|
||||
message=message,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
async def execute_workflow(
|
||||
self,
|
||||
message: dict[str, Any] | None = None,
|
||||
stream_id: str | None = None,
|
||||
context: WorkflowContext | None = None,
|
||||
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
|
||||
"""执行 Workflow Pipeline 的快捷方法。"""
|
||||
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
|
||||
resp = await self.invoke_plugin(
|
||||
method="plugin.invoke_workflow_step",
|
||||
plugin_id=plugin_id,
|
||||
component_name=component_name,
|
||||
args=args,
|
||||
)
|
||||
payload = resp.payload
|
||||
if payload.get("success"):
|
||||
result = payload.get("result")
|
||||
return result if isinstance(result, dict) else {}
|
||||
raise RuntimeError(payload.get("result", "workflow step invoke failed"))
|
||||
|
||||
return await self._workflow_executor.execute(
|
||||
invoke_fn=_invoke,
|
||||
message=message,
|
||||
stream_id=stream_id,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 Supervisor
|
||||
|
||||
@@ -137,11 +203,6 @@ class PluginSupervisor:
|
||||
|
||||
由主进程业务逻辑调用,通过 RPC 转发给 Runner。
|
||||
"""
|
||||
# 熔断检查
|
||||
breaker = self._breakers.get(plugin_id)
|
||||
if not breaker.allow_request():
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, f"插件 {plugin_id} 已被熔断")
|
||||
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
method=method,
|
||||
@@ -152,10 +213,8 @@ class PluginSupervisor:
|
||||
},
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
breaker.record_success()
|
||||
return response
|
||||
except RPCError:
|
||||
breaker.record_failure()
|
||||
raise
|
||||
|
||||
async def reload_plugins(self, reason: str = "manual") -> None:
|
||||
@@ -232,12 +291,20 @@ class PluginSupervisor:
|
||||
self._policy.register_plugin(
|
||||
plugin_id=reg.plugin_id,
|
||||
generation=envelope.generation,
|
||||
capabilities=reg.capabilities_required,
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
|
||||
# 在 ComponentRegistry 中注册组件
|
||||
self._component_registry.register_plugin_components(
|
||||
plugin_id=reg.plugin_id,
|
||||
components=[c.model_dump() for c in reg.components],
|
||||
)
|
||||
|
||||
stats = self._component_registry.get_stats()
|
||||
logger.info(
|
||||
f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功,"
|
||||
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
|
||||
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required},"
|
||||
f"注册表总计: {stats}"
|
||||
)
|
||||
|
||||
return envelope.make_response(payload={"accepted": True})
|
||||
@@ -294,10 +361,32 @@ class PluginSupervisor:
|
||||
await self._runner_process.wait()
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""周期性健康检查"""
|
||||
"""周期性健康检查 + 崩溃自动重启"""
|
||||
while self._running:
|
||||
await asyncio.sleep(self._health_interval)
|
||||
|
||||
# 检查 Runner 进程是否意外退出
|
||||
if self._runner_process and self._runner_process.returncode is not None:
|
||||
exit_code = self._runner_process.returncode
|
||||
logger.warning(f"Runner 进程已退出 (exit_code={exit_code})")
|
||||
|
||||
if self._restart_count < self._max_restart_attempts:
|
||||
self._restart_count += 1
|
||||
logger.info(f"尝试重启 Runner ({self._restart_count}/{self._max_restart_attempts})")
|
||||
# 清理旧的组件注册
|
||||
for plugin_id in list(self._registered_plugins.keys()):
|
||||
self._component_registry.remove_components_by_plugin(plugin_id)
|
||||
self._policy.revoke_plugin(plugin_id)
|
||||
self._registered_plugins.clear()
|
||||
|
||||
try:
|
||||
await self._spawn_runner()
|
||||
except Exception as e:
|
||||
logger.error(f"Runner 重启失败: {e}", exc_info=True)
|
||||
else:
|
||||
logger.error(f"Runner 连续崩溃 {self._max_restart_attempts} 次,停止重启")
|
||||
continue
|
||||
|
||||
if not self._rpc_server.is_connected:
|
||||
logger.warning("Runner 未连接,跳过健康检查")
|
||||
continue
|
||||
@@ -307,6 +396,9 @@ class PluginSupervisor:
|
||||
health = HealthPayload.model_validate(resp.payload)
|
||||
if not health.healthy:
|
||||
logger.warning(f"Runner 健康检查异常: {health}")
|
||||
else:
|
||||
# 健康检查成功,重置重启计数
|
||||
self._restart_count = 0
|
||||
except RPCError as e:
|
||||
logger.error(f"健康检查失败: {e}")
|
||||
except asyncio.CancelledError:
|
||||
|
||||
397
src/plugin_runtime/host/workflow_executor.py
Normal file
397
src/plugin_runtime/host/workflow_executor.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Host-side WorkflowExecutor
|
||||
|
||||
6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS)
|
||||
|
||||
每个阶段执行顺序:
|
||||
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
|
||||
2. 按 priority 降序排列
|
||||
3. 串行执行 blocking hook(可修改 message,返回 HookResult)
|
||||
4. 并发执行 non-blocking hook(只读)
|
||||
5. 检查是否有 SKIP_STAGE 或 ABORT
|
||||
6. PLAN 阶段内置 Command 匹配路由
|
||||
|
||||
支持:
|
||||
- HookResult: CONTINUE / SKIP_STAGE / ABORT
|
||||
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
|
||||
- stage_outputs: 阶段间带命名空间的数据传递
|
||||
- modification_log: 消息修改审计
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable, Optional
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.host.workflow_executor")
|
||||
|
||||
# 阶段顺序
|
||||
STAGE_SEQUENCE: list[str] = [
|
||||
"ingress",
|
||||
"pre_process",
|
||||
"plan",
|
||||
"tool_execute",
|
||||
"post_process",
|
||||
"egress",
|
||||
]
|
||||
|
||||
# HookResult 常量(与 SDK HookResult enum 值对应)
|
||||
HOOK_CONTINUE = "continue"
|
||||
HOOK_SKIP_STAGE = "skip_stage"
|
||||
HOOK_ABORT = "abort"
|
||||
|
||||
|
||||
class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: list[str]):
|
||||
self.stage = stage
|
||||
self.hook_name = hook_name
|
||||
self.timestamp = time.perf_counter()
|
||||
self.fields_changed = fields_changed
|
||||
|
||||
|
||||
class WorkflowContext:
|
||||
"""Workflow 执行上下文"""
|
||||
|
||||
def __init__(self, trace_id: str | None = None, stream_id: str | None = None):
|
||||
self.trace_id = trace_id or uuid.uuid4().hex
|
||||
self.stream_id = stream_id
|
||||
self.timings: dict[str, float] = {}
|
||||
self.errors: list[str] = []
|
||||
# 阶段间数据传递(按 stage 命名空间隔离)
|
||||
self.stage_outputs: dict[str, dict[str, Any]] = {}
|
||||
# 消息修改审计日志
|
||||
self.modification_log: list[ModificationRecord] = []
|
||||
# PLAN 阶段命令匹配结果
|
||||
self.matched_command: str | None = None
|
||||
|
||||
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
|
||||
self.stage_outputs.setdefault(stage, {})[key] = value
|
||||
|
||||
def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
|
||||
return self.stage_outputs.get(stage, {}).get(key, default)
|
||||
|
||||
|
||||
class WorkflowResult:
|
||||
"""Workflow 执行结果"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: str = "completed", # completed / aborted / failed
|
||||
return_message: str = "",
|
||||
stopped_at: str = "",
|
||||
diagnostics: dict[str, Any] | None = None,
|
||||
):
|
||||
self.status = status
|
||||
self.return_message = return_message
|
||||
self.stopped_at = stopped_at
|
||||
self.diagnostics = diagnostics or {}
|
||||
|
||||
|
||||
# invoke_fn 签名
|
||||
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]]
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
"""Host-side Workflow 执行器
|
||||
|
||||
实现 stage-based pipeline + per-stage hook chain with priority + early return。
|
||||
"""
|
||||
|
||||
def __init__(self, registry: ComponentRegistry):
|
||||
self._registry = registry
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any] | None = None,
|
||||
stream_id: str | None = None,
|
||||
context: WorkflowContext | None = None,
|
||||
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
|
||||
"""执行 workflow pipeline。
|
||||
|
||||
Returns:
|
||||
(result, final_message, context)
|
||||
"""
|
||||
ctx = context or WorkflowContext(stream_id=stream_id)
|
||||
current_message = dict(message) if message else None
|
||||
|
||||
for stage in STAGE_SEQUENCE:
|
||||
stage_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
# PLAN 阶段: 先做 Command 路由
|
||||
if stage == "plan" and current_message:
|
||||
cmd_result = await self._route_command(
|
||||
invoke_fn, current_message, ctx
|
||||
)
|
||||
if cmd_result is not None:
|
||||
# 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
|
||||
ctx.set_stage_output("plan", "command_result", cmd_result)
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
continue
|
||||
|
||||
# 获取该阶段所有 hook(已按 priority 降序排列)
|
||||
all_steps = self._registry.get_workflow_steps(stage)
|
||||
if not all_steps:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
continue
|
||||
|
||||
# 1. Pre-filter
|
||||
filtered_steps = self._pre_filter(all_steps, current_message)
|
||||
|
||||
# 2. 分离 blocking 和 non-blocking
|
||||
blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
|
||||
nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
|
||||
|
||||
# 3. 串行执行 blocking hook
|
||||
skip_stage = False
|
||||
for step in blocking_steps:
|
||||
hook_result, modified, step_error = await self._invoke_step(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
)
|
||||
|
||||
if step_error:
|
||||
error_policy = step.metadata.get("error_policy", "abort")
|
||||
ctx.errors.append(f"{step.full_name}: {step_error}")
|
||||
|
||||
if error_policy == "abort":
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="failed",
|
||||
return_message=step_error,
|
||||
stopped_at=stage,
|
||||
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
elif error_policy == "skip":
|
||||
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
|
||||
continue
|
||||
else: # log
|
||||
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
|
||||
continue
|
||||
|
||||
# 更新消息(仅 blocking hook 有权修改)
|
||||
if modified:
|
||||
changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys())
|
||||
ctx.modification_log.append(
|
||||
ModificationRecord(stage, step.full_name, changed_fields)
|
||||
)
|
||||
current_message = modified
|
||||
|
||||
if hook_result == HOOK_ABORT:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="aborted",
|
||||
return_message=f"aborted by {step.full_name}",
|
||||
stopped_at=stage,
|
||||
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
if hook_result == HOOK_SKIP_STAGE:
|
||||
skip_stage = True
|
||||
break
|
||||
|
||||
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
||||
if nonblocking_steps and not skip_stage:
|
||||
nb_tasks = [
|
||||
self._invoke_step_fire_and_forget(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
)
|
||||
for step in nonblocking_steps
|
||||
]
|
||||
# 并发执行但不阻塞 pipeline
|
||||
for task in [asyncio.create_task(t) for t in nb_tasks]:
|
||||
task.add_done_callback(lambda _: None)
|
||||
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
|
||||
except Exception as e:
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
ctx.errors.append(f"{stage}: {e}")
|
||||
logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="failed",
|
||||
return_message=str(e),
|
||||
stopped_at=stage,
|
||||
diagnostics={"trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
return (
|
||||
WorkflowResult(
|
||||
status="completed",
|
||||
return_message="workflow completed",
|
||||
diagnostics={"trace_id": ctx.trace_id},
|
||||
),
|
||||
current_message,
|
||||
ctx,
|
||||
)
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
|
||||
def _pre_filter(
|
||||
self,
|
||||
steps: list[RegisteredComponent],
|
||||
message: dict[str, Any] | None,
|
||||
) -> list[RegisteredComponent]:
|
||||
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
|
||||
if not message:
|
||||
return steps
|
||||
|
||||
result = []
|
||||
for step in steps:
|
||||
filter_cond = step.metadata.get("filter", {})
|
||||
if not filter_cond:
|
||||
result.append(step)
|
||||
continue
|
||||
if self._match_filter(filter_cond, message):
|
||||
result.append(step)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _match_filter(filter_cond: dict[str, Any], message: dict[str, Any]) -> bool:
|
||||
"""简单 key-value 匹配过滤。
|
||||
|
||||
filter 中的每个 key 必须在 message 中存在且值相等,
|
||||
全部匹配才通过。
|
||||
"""
|
||||
for key, expected in filter_cond.items():
|
||||
actual = message.get(key)
|
||||
if isinstance(expected, list):
|
||||
if actual not in expected:
|
||||
return False
|
||||
elif actual != expected:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _invoke_step(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: dict[str, Any] | None,
|
||||
) -> tuple[str, dict[str, Any] | None, str | None]:
|
||||
"""调用单个 blocking hook。
|
||||
|
||||
Returns:
|
||||
(hook_result, modified_message, error_string_or_None)
|
||||
"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None
|
||||
step_key = f"{stage}:{step.full_name}"
|
||||
step_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
coro = invoke_fn(step.plugin_id, step.name, {
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
})
|
||||
resp = await asyncio.wait_for(coro, timeout=timeout_sec) if timeout_sec else await coro
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
|
||||
hook_result = resp.get("hook_result", HOOK_CONTINUE)
|
||||
modified_message = resp.get("modified_message")
|
||||
# 存 stage output(如果 hook 提供了)
|
||||
stage_out = resp.get("stage_output")
|
||||
if isinstance(stage_out, dict):
|
||||
for k, v in stage_out.items():
|
||||
ctx.set_stage_output(stage, k, v)
|
||||
|
||||
return hook_result, modified_message, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
|
||||
|
||||
except Exception as e:
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
return HOOK_CONTINUE, None, str(e)
|
||||
|
||||
async def _invoke_step_fire_and_forget(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Non-blocking hook 调用,只读,忽略结果。"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None
|
||||
|
||||
try:
|
||||
coro = invoke_fn(step.plugin_id, step.name, {
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
})
|
||||
if timeout_sec:
|
||||
await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||
else:
|
||||
await coro
|
||||
except Exception as e:
|
||||
logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
|
||||
|
||||
async def _route_command(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any],
|
||||
ctx: WorkflowContext,
|
||||
) -> dict[str, Any] | None:
|
||||
"""PLAN 阶段内置 Command 路由。
|
||||
|
||||
在 registry 中查找匹配的 command 组件,
|
||||
匹配到则直接路由到对应 command handler,返回执行结果。
|
||||
不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。
|
||||
"""
|
||||
plain_text = message.get("plain_text", "")
|
||||
if not plain_text:
|
||||
return None
|
||||
|
||||
matched = self._registry.find_command_by_text(plain_text)
|
||||
if matched is None:
|
||||
return None
|
||||
|
||||
ctx.matched_command = matched.full_name
|
||||
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
||||
|
||||
try:
|
||||
resp = await invoke_fn(matched.plugin_id, matched.name, {
|
||||
"text": plain_text,
|
||||
"message": message,
|
||||
"trace_id": ctx.trace_id,
|
||||
})
|
||||
return resp
|
||||
except Exception as e:
|
||||
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
||||
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
|
||||
"""返回 new 中与 old 不同的 key 列表。"""
|
||||
changed = []
|
||||
for k in new:
|
||||
if k not in old or old[k] != new[k]:
|
||||
changed.append(k)
|
||||
return changed
|
||||
@@ -1 +1 @@
|
||||
# Protocol 层 - RPC 消息模型、编解码、错误码
|
||||
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
"""MsgPack / JSON 编解码器
|
||||
|
||||
提供统一的消息编解码接口,生产环境默认使用 MsgPack,
|
||||
开发调试模式可切换为 JSON(仅编解码切换,传输层不变)。
|
||||
"""
|
||||
"""MsgPack 编解码器"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
|
||||
import msgpack
|
||||
|
||||
from .envelope import Envelope
|
||||
@@ -30,7 +24,7 @@ class Codec:
|
||||
|
||||
|
||||
class MsgPackCodec(Codec):
|
||||
"""MsgPack 编解码器(生产默认)"""
|
||||
"""MsgPack 编解码器"""
|
||||
|
||||
def encode(self, obj: dict[str, Any]) -> bytes:
|
||||
return msgpack.packb(obj, use_bin_type=True)
|
||||
@@ -47,34 +41,3 @@ class MsgPackCodec(Codec):
|
||||
def decode_envelope(self, data: bytes) -> Envelope:
|
||||
raw = self.decode(data)
|
||||
return Envelope.model_validate(raw)
|
||||
|
||||
|
||||
class JsonCodec(Codec):
|
||||
"""JSON 编解码器(开发调试用)"""
|
||||
|
||||
def encode(self, obj: dict[str, Any]) -> bytes:
|
||||
return json.dumps(obj, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
def decode(self, data: bytes) -> dict[str, Any]:
|
||||
result = json.loads(data.decode("utf-8"))
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError(f"期望解码为 dict,实际为 {type(result)}")
|
||||
return result
|
||||
|
||||
def encode_envelope(self, envelope: Envelope) -> bytes:
|
||||
return self.encode(envelope.model_dump())
|
||||
|
||||
def decode_envelope(self, data: bytes) -> Envelope:
|
||||
raw = self.decode(data)
|
||||
return Envelope.model_validate(raw)
|
||||
|
||||
|
||||
def create_codec(use_json: bool = False) -> Codec:
|
||||
"""创建编解码器实例
|
||||
|
||||
Args:
|
||||
use_json: 是否使用 JSON(开发模式)。默认使用 MsgPack。
|
||||
"""
|
||||
if use_json:
|
||||
return JsonCodec()
|
||||
return MsgPackCodec()
|
||||
|
||||
@@ -1 +1 @@
|
||||
# Runner 端 - 插件加载与执行进程
|
||||
|
||||
|
||||
137
src/plugin_runtime/runner/manifest_validator.py
Normal file
137
src/plugin_runtime/runner/manifest_validator.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Manifest 校验与版本兼容性
|
||||
|
||||
从旧系统的 ManifestValidator / VersionComparator 对齐移植,
|
||||
适配新 plugin_runtime 的 _manifest.json 格式。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.runner.manifest_validator")
|
||||
|
||||
|
||||
class VersionComparator:
|
||||
"""语义化版本号比较器"""
|
||||
|
||||
@staticmethod
|
||||
def normalize_version(version: str) -> str:
|
||||
if not version:
|
||||
return "0.0.0"
|
||||
normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
|
||||
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
|
||||
return "0.0.0"
|
||||
parts = normalized.split(".")
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
return ".".join(parts[:3])
|
||||
|
||||
@staticmethod
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
normalized = VersionComparator.normalize_version(version)
|
||||
try:
|
||||
parts = normalized.split(".")
|
||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||
except (ValueError, IndexError):
|
||||
return (0, 0, 0)
|
||||
|
||||
@staticmethod
|
||||
def compare(v1: str, v2: str) -> int:
|
||||
t1 = VersionComparator.parse_version(v1)
|
||||
t2 = VersionComparator.parse_version(v2)
|
||||
if t1 < t2:
|
||||
return -1
|
||||
elif t1 > t2:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def is_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]:
|
||||
if not min_version and not max_version:
|
||||
return True, ""
|
||||
vn = VersionComparator.normalize_version(version)
|
||||
if min_version:
|
||||
mn = VersionComparator.normalize_version(min_version)
|
||||
if VersionComparator.compare(vn, mn) < 0:
|
||||
return False, f"版本 {vn} 低于最小要求 {mn}"
|
||||
if max_version:
|
||||
mx = VersionComparator.normalize_version(max_version)
|
||||
if VersionComparator.compare(vn, mx) > 0:
|
||||
return False, f"版本 {vn} 高于最大支持 {mx}"
|
||||
return True, ""
|
||||
|
||||
|
||||
class ManifestValidator:
|
||||
"""_manifest.json 校验器"""
|
||||
|
||||
REQUIRED_FIELDS = ["name", "version", "description", "author"]
|
||||
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
|
||||
SUPPORTED_MANIFEST_VERSIONS = [1, 2]
|
||||
|
||||
def __init__(self, host_version: str = ""):
|
||||
self._host_version = host_version
|
||||
self.errors: list[str] = []
|
||||
self.warnings: list[str] = []
|
||||
|
||||
def validate(self, manifest: dict[str, Any]) -> bool:
|
||||
"""校验 manifest 数据,返回是否通过(errors 为空即通过)。"""
|
||||
self.errors.clear()
|
||||
self.warnings.clear()
|
||||
|
||||
self._check_required_fields(manifest)
|
||||
self._check_manifest_version(manifest)
|
||||
self._check_author(manifest)
|
||||
self._check_host_compatibility(manifest)
|
||||
self._check_recommended(manifest)
|
||||
|
||||
if self.errors:
|
||||
for e in self.errors:
|
||||
logger.error(f"Manifest 校验失败: {e}")
|
||||
if self.warnings:
|
||||
for w in self.warnings:
|
||||
logger.warning(f"Manifest 警告: {w}")
|
||||
|
||||
return len(self.errors) == 0
|
||||
|
||||
def _check_required_fields(self, manifest: dict[str, Any]) -> None:
|
||||
for field in self.REQUIRED_FIELDS:
|
||||
if field not in manifest:
|
||||
self.errors.append(f"缺少必需字段: {field}")
|
||||
elif not manifest[field]:
|
||||
self.errors.append(f"必需字段不能为空: {field}")
|
||||
|
||||
def _check_manifest_version(self, manifest: dict[str, Any]) -> None:
|
||||
mv = manifest.get("manifest_version")
|
||||
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
|
||||
self.errors.append(
|
||||
f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}"
|
||||
)
|
||||
|
||||
def _check_author(self, manifest: dict[str, Any]) -> None:
|
||||
author = manifest.get("author")
|
||||
if author is None:
|
||||
return
|
||||
if isinstance(author, dict):
|
||||
if "name" not in author or not author["name"]:
|
||||
self.errors.append("author 对象缺少 name 字段")
|
||||
elif isinstance(author, str):
|
||||
if not author.strip():
|
||||
self.errors.append("author 不能为空")
|
||||
else:
|
||||
self.errors.append("author 应为字符串或 {name, url} 对象")
|
||||
|
||||
def _check_host_compatibility(self, manifest: dict[str, Any]) -> None:
|
||||
host_app = manifest.get("host_application")
|
||||
if not isinstance(host_app, dict) or not self._host_version:
|
||||
return
|
||||
min_v = host_app.get("min_version", "")
|
||||
max_v = host_app.get("max_version", "")
|
||||
ok, msg = VersionComparator.is_in_range(self._host_version, min_v, max_v)
|
||||
if not ok:
|
||||
self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})")
|
||||
|
||||
def _check_recommended(self, manifest: dict[str, Any]) -> None:
|
||||
for field in self.RECOMMENDED_FIELDS:
|
||||
if field not in manifest or not manifest[field]:
|
||||
self.warnings.append(f"建议填写字段: {field}")
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
在 Runner 进程中负责发现和加载插件。
|
||||
插件通过 SDK 编写,不再 import src.*。
|
||||
支持:manifest 校验、依赖解析(拓扑排序)、生命周期钩子。
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
import importlib
|
||||
@@ -13,6 +15,8 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.runner.plugin_loader")
|
||||
|
||||
|
||||
@@ -32,6 +36,20 @@ class PluginMeta:
|
||||
self.manifest = manifest
|
||||
self.version = manifest.get("version", "1.0.0")
|
||||
self.capabilities_required = manifest.get("capabilities", [])
|
||||
self.dependencies: list[str] = self._extract_dependencies(manifest)
|
||||
|
||||
@staticmethod
|
||||
def _extract_dependencies(manifest: dict[str, Any]) -> list[str]:
|
||||
raw = manifest.get("dependencies", [])
|
||||
result: list[str] = []
|
||||
for dep in raw:
|
||||
if isinstance(dep, str):
|
||||
result.append(dep.strip())
|
||||
elif isinstance(dep, dict):
|
||||
name = str(dep.get("name", "")).strip()
|
||||
if name:
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
|
||||
class PluginLoader:
|
||||
@@ -43,19 +61,22 @@ class PluginLoader:
|
||||
- plugin.py: 插件入口模块(导出 create_plugin 工厂函数)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, host_version: str = ""):
|
||||
self._loaded_plugins: dict[str, PluginMeta] = {}
|
||||
self._failed_plugins: dict[str, str] = {}
|
||||
self._manifest_validator = ManifestValidator(host_version=host_version)
|
||||
|
||||
def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]:
|
||||
"""扫描多个目录并加载所有插件
|
||||
"""扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
|
||||
|
||||
Args:
|
||||
plugin_dirs: 插件目录列表
|
||||
|
||||
Returns:
|
||||
成功加载的插件元数据列表
|
||||
成功加载的插件元数据列表(按依赖顺序)
|
||||
"""
|
||||
results = []
|
||||
# 第一阶段:发现并校验 manifest
|
||||
candidates: dict[str, tuple[str, dict[str, Any], str]] = {} # id -> (dir, manifest, plugin_path)
|
||||
for base_dir in plugin_dirs:
|
||||
if not os.path.isdir(base_dir):
|
||||
logger.warning(f"插件目录不存在: {base_dir}")
|
||||
@@ -73,12 +94,40 @@ class PluginLoader:
|
||||
continue
|
||||
|
||||
try:
|
||||
meta = self._load_single_plugin(plugin_dir, manifest_path, plugin_path)
|
||||
if meta:
|
||||
self._loaded_plugins[meta.plugin_id] = meta
|
||||
results.append(meta)
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件失败 [{plugin_dir}]: {e}", exc_info=True)
|
||||
self._failed_plugins[entry] = f"manifest 解析失败: {e}"
|
||||
logger.error(f"插件 {entry} manifest 解析失败: {e}")
|
||||
continue
|
||||
|
||||
if not self._manifest_validator.validate(manifest):
|
||||
errors = "; ".join(self._manifest_validator.errors)
|
||||
self._failed_plugins[entry] = f"manifest 校验失败: {errors}"
|
||||
continue
|
||||
|
||||
plugin_id = manifest.get("name", entry)
|
||||
candidates[plugin_id] = (plugin_dir, manifest, plugin_path)
|
||||
|
||||
# 第二阶段:依赖解析(拓扑排序)
|
||||
load_order, failed_deps = self._resolve_dependencies(candidates)
|
||||
|
||||
for pid, reason in failed_deps.items():
|
||||
self._failed_plugins[pid] = reason
|
||||
logger.error(f"插件 {pid} 依赖解析失败: {reason}")
|
||||
|
||||
# 第三阶段:按依赖顺序加载
|
||||
results = []
|
||||
for plugin_id in load_order:
|
||||
plugin_dir, manifest, plugin_path = candidates[plugin_id]
|
||||
try:
|
||||
meta = self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path)
|
||||
if meta:
|
||||
self._loaded_plugins[meta.plugin_id] = meta
|
||||
results.append(meta)
|
||||
except Exception as e:
|
||||
self._failed_plugins[plugin_id] = str(e)
|
||||
logger.error(f"加载插件失败 [{plugin_id}]: {e}", exc_info=True)
|
||||
|
||||
return results
|
||||
|
||||
@@ -90,15 +139,78 @@ class PluginLoader:
|
||||
"""列出所有已加载的插件 ID"""
|
||||
return list(self._loaded_plugins.keys())
|
||||
|
||||
def _load_single_plugin(self, plugin_dir: str, manifest_path: str, plugin_path: str) -> PluginMeta | None:
|
||||
@property
|
||||
def failed_plugins(self) -> dict[str, str]:
|
||||
return dict(self._failed_plugins)
|
||||
|
||||
# ──── 依赖解析 ────────────────────────────────────────────
|
||||
|
||||
def _resolve_dependencies(
|
||||
self,
|
||||
candidates: dict[str, tuple[str, dict[str, Any], str]],
|
||||
) -> tuple[list[str], dict[str, str]]:
|
||||
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
|
||||
available = set(candidates.keys())
|
||||
dep_graph: dict[str, set[str]] = {}
|
||||
failed: dict[str, str] = {}
|
||||
|
||||
for pid, (_, manifest, _) in candidates.items():
|
||||
raw_deps = manifest.get("dependencies", [])
|
||||
resolved: set[str] = set()
|
||||
missing: list[str] = []
|
||||
for dep in raw_deps:
|
||||
dep_name = dep if isinstance(dep, str) else str(dep.get("name", ""))
|
||||
dep_name = dep_name.strip()
|
||||
if not dep_name or dep_name == pid:
|
||||
continue
|
||||
if dep_name in available:
|
||||
resolved.add(dep_name)
|
||||
else:
|
||||
missing.append(dep_name)
|
||||
if missing:
|
||||
failed[pid] = f"缺少依赖: {', '.join(missing)}"
|
||||
dep_graph[pid] = resolved
|
||||
|
||||
# 移除失败项
|
||||
for pid in failed:
|
||||
dep_graph.pop(pid, None)
|
||||
|
||||
# Kahn 拓扑排序
|
||||
indegree = {pid: len(deps) for pid, deps in dep_graph.items()}
|
||||
reverse: dict[str, set[str]] = {pid: set() for pid in dep_graph}
|
||||
for pid, deps in dep_graph.items():
|
||||
for d in deps:
|
||||
if d in reverse:
|
||||
reverse[d].add(pid)
|
||||
|
||||
queue = deque(sorted(pid for pid, deg in indegree.items() if deg == 0))
|
||||
sorted_order: list[str] = []
|
||||
|
||||
while queue:
|
||||
current = queue.popleft()
|
||||
sorted_order.append(current)
|
||||
for dependent in sorted(reverse.get(current, [])):
|
||||
indegree[dependent] -= 1
|
||||
if indegree[dependent] == 0:
|
||||
queue.append(dependent)
|
||||
|
||||
cycle_plugins = {pid for pid, deg in indegree.items() if deg > 0}
|
||||
for pid in cycle_plugins:
|
||||
failed[pid] = "检测到循环依赖"
|
||||
|
||||
return sorted_order, failed
|
||||
|
||||
# ──── 单个插件加载 ────────────────────────────────────────
|
||||
|
||||
def _load_single_plugin(
|
||||
self,
|
||||
plugin_id: str,
|
||||
plugin_dir: str,
|
||||
manifest: dict[str, Any],
|
||||
plugin_path: str,
|
||||
) -> PluginMeta | None:
|
||||
"""加载单个插件"""
|
||||
# 1. 读取 manifest
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
|
||||
plugin_id = os.path.basename(plugin_dir)
|
||||
|
||||
# 2. 动态导入插件模块
|
||||
# 动态导入插件模块
|
||||
module_name = f"_maibot_plugin_{plugin_id}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
|
||||
if spec is None or spec.loader is None:
|
||||
@@ -109,7 +221,7 @@ class PluginLoader:
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 3. 调用工厂函数创建插件实例
|
||||
# 调用工厂函数创建插件实例
|
||||
create_plugin = getattr(module, "create_plugin", None)
|
||||
if create_plugin is None:
|
||||
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数")
|
||||
|
||||
@@ -14,7 +14,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from src.plugin_runtime.protocol.codec import Codec, create_codec
|
||||
from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
PROTOCOL_VERSION,
|
||||
Envelope,
|
||||
@@ -49,7 +49,7 @@ class RPCClient:
|
||||
):
|
||||
self._host_address = host_address
|
||||
self._session_token = session_token
|
||||
self._codec = codec or create_codec()
|
||||
self._codec = codec or MsgPackCodec()
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Connection | None = None
|
||||
|
||||
@@ -71,7 +71,23 @@ class PluginRunner:
|
||||
plugins = self._loader.discover_and_load(self._plugin_dirs)
|
||||
logger.info(f"已加载 {len(plugins)} 个插件")
|
||||
|
||||
# 4. 向 Host 注册所有插件的组件
|
||||
# 4. 调用 on_load 生命周期钩子 + 注入 RPC 客户端供 SDK context 使用
|
||||
for meta in plugins:
|
||||
instance = meta.instance
|
||||
# 注入 _rpc_client 以便 PluginContext 可以发起能力调用
|
||||
if hasattr(instance, "_ctx"):
|
||||
ctx = instance._ctx
|
||||
if hasattr(ctx, "_set_rpc_client"):
|
||||
ctx._set_rpc_client(self._rpc_client)
|
||||
if hasattr(instance, "on_load"):
|
||||
try:
|
||||
ret = instance.on_load()
|
||||
if asyncio.iscoroutine(ret):
|
||||
await ret
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {meta.plugin_id} on_load 失败: {e}", exc_info=True)
|
||||
|
||||
# 5. 向 Host 注册所有插件的组件
|
||||
for meta in plugins:
|
||||
await self._register_plugin(meta)
|
||||
|
||||
@@ -92,6 +108,7 @@ class PluginRunner:
|
||||
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.emit_event", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
|
||||
self._rpc_client.register_method("plugin.health", self._handle_health)
|
||||
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
|
||||
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
|
||||
@@ -169,6 +186,57 @@ class PluginRunner:
|
||||
resp_payload = InvokeResultPayload(success=False, result=str(e))
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
|
||||
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
|
||||
"""处理 WorkflowStep 调用请求
|
||||
|
||||
与通用 invoke 不同,会将返回值规范化为
|
||||
{hook_result, modified_message, stage_output} 格式。
|
||||
"""
|
||||
try:
|
||||
invoke = InvokePayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
f"插件 {plugin_id} 未加载",
|
||||
)
|
||||
|
||||
instance = meta.instance
|
||||
component_name = invoke.component_name
|
||||
handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None)
|
||||
|
||||
if handler_method is None or not callable(handler_method):
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"插件 {plugin_id} 无组件: {component_name}",
|
||||
)
|
||||
|
||||
try:
|
||||
raw = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
|
||||
|
||||
# 规范化返回值
|
||||
if raw is None:
|
||||
result = {"hook_result": "continue"}
|
||||
elif isinstance(raw, str):
|
||||
# 允许直接返回 hook_result 字符串
|
||||
result = {"hook_result": raw}
|
||||
elif isinstance(raw, dict):
|
||||
result = raw
|
||||
result.setdefault("hook_result", "continue")
|
||||
else:
|
||||
result = {"hook_result": "continue"}
|
||||
|
||||
resp_payload = InvokeResultPayload(success=True, result=result)
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True)
|
||||
resp_payload = InvokeResultPayload(success=False, result=str(e))
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
|
||||
async def _handle_health(self, envelope: Envelope) -> Envelope:
|
||||
"""处理健康检查"""
|
||||
uptime_ms = int((time.monotonic() - self._start_time) * 1000)
|
||||
@@ -185,8 +253,17 @@ class PluginRunner:
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
|
||||
"""处理关停"""
|
||||
logger.info("收到 shutdown 信号,准备退出")
|
||||
"""处理关停 — 调用所有插件的 on_unload 后退出"""
|
||||
logger.info("收到 shutdown 信号,开始调用 on_unload")
|
||||
for plugin_id in self._loader.list_plugins():
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta and hasattr(meta.instance, "on_unload"):
|
||||
try:
|
||||
ret = meta.instance.on_unload()
|
||||
if asyncio.iscoroutine(ret):
|
||||
await ret
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True)
|
||||
self._shutting_down = True
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
@@ -209,6 +286,43 @@ class PluginRunner:
|
||||
return self._rpc_client
|
||||
|
||||
|
||||
# ─── sys.path 隔离 ────────────────────────────────────────
|
||||
|
||||
def _isolate_sys_path(plugin_dirs: list[str]) -> None:
|
||||
"""清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。
|
||||
|
||||
防止插件代码 import 主程序模块读取运行时数据。
|
||||
"""
|
||||
import sysconfig
|
||||
|
||||
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
|
||||
stdlib_paths = set()
|
||||
for key in ("stdlib", "platstdlib", "purelib", "platlib"):
|
||||
path = sysconfig.get_path(key)
|
||||
if path:
|
||||
stdlib_paths.add(os.path.normpath(path))
|
||||
|
||||
allowed = set()
|
||||
for p in sys.path:
|
||||
norm = os.path.normpath(p)
|
||||
# 保留标准库和 site-packages
|
||||
if any(norm.startswith(sp) for sp in stdlib_paths):
|
||||
allowed.add(p)
|
||||
# 保留 site-packages(第三方库 + SDK)
|
||||
if "site-packages" in norm or "dist-packages" in norm:
|
||||
allowed.add(p)
|
||||
|
||||
# 添加插件目录
|
||||
for d in plugin_dirs:
|
||||
allowed.add(os.path.normpath(d))
|
||||
|
||||
# 添加当前 runner 模块所在路径(使得 src.plugin_runtime 可导入)
|
||||
runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
allowed.add(runtime_root)
|
||||
|
||||
sys.path[:] = [p for p in sys.path if p in allowed]
|
||||
|
||||
|
||||
# ─── 进程入口 ──────────────────────────────────────────────
|
||||
|
||||
async def _async_main() -> None:
|
||||
@@ -223,6 +337,9 @@ async def _async_main() -> None:
|
||||
|
||||
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
|
||||
|
||||
# sys.path 隔离: 只保留标准库、SDK 包、插件目录
|
||||
_isolate_sys_path(plugin_dirs)
|
||||
|
||||
runner = PluginRunner(host_address, session_token, plugin_dirs)
|
||||
|
||||
# 注册信号处理
|
||||
|
||||
@@ -1 +1 @@
|
||||
# Transport 层 - 跨平台本地 IPC 传输抽象
|
||||
|
||||
|
||||
Reference in New Issue
Block a user