feat: Enhance plugin runtime with new component registry and workflow executor

- Introduced `ComponentRegistry` for managing plugin components with support for registration, enabling/disabling, and querying by type and plugin.
- Added `EventDispatcher` to handle event distribution to registered event handlers, supporting both blocking and non-blocking execution.
- Implemented `WorkflowExecutor` to manage a linear workflow execution across multiple stages, including command routing and error handling.
- Created `ManifestValidator` for validating plugin manifests against required fields and version compatibility.
- Updated `RPCClient` to use `MsgPackCodec` for message encoding.
- Enhanced `PluginRunner` to support lifecycle hooks for plugins, including `on_load` and `on_unload`.
- Added sys.path isolation to restrict plugin access to only necessary directories.
This commit is contained in:
DrSmoothl
2026-03-06 11:55:59 +08:00
parent 61dc15a513
commit 2f21cd00bc
19 changed files with 1970 additions and 318 deletions

View File

@@ -82,24 +82,8 @@ class TestProtocol:
assert decoded.payload["number"] == 42 assert decoded.payload["number"] == 42
def test_json_codec(self): def test_json_codec(self):
"""JSON 编解码""" """JSON 编解码已移除,仅保留 MsgPack"""
from src.plugin_runtime.protocol.codec import JsonCodec pass
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
codec = JsonCodec()
env = Envelope(
request_id=200,
message_type=MessageType.EVENT,
method="plugin.config_updated",
payload={"config_version": "2.0"},
)
data = codec.encode_envelope(env)
assert isinstance(data, bytes)
decoded = codec.decode_envelope(data)
assert decoded.request_id == 200
assert decoded.is_event()
def test_request_id_generator(self): def test_request_id_generator(self):
"""请求 ID 生成器单调递增""" """请求 ID 生成器单调递增"""
@@ -226,7 +210,6 @@ class TestHost:
plugin_id="test_plugin", plugin_id="test_plugin",
generation=1, generation=1,
capabilities=["send.text", "db.query"], capabilities=["send.text", "db.query"],
limits={"qps": 10, "burst": 20},
) )
assert token.plugin_id == "test_plugin" assert token.plugin_id == "test_plugin"
@@ -244,39 +227,13 @@ class TestHost:
ok, reason = engine.check_capability("unknown", "send.text") ok, reason = engine.check_capability("unknown", "send.text")
assert not ok assert not ok
def test_circuit_breaker(self): def test_circuit_breaker_removed(self):
"""熔断器测试""" """熔断器已移除,验证 supervisor 不依赖它"""
from src.plugin_runtime.host.circuit_breaker import CircuitBreaker, CircuitState pass
breaker = CircuitBreaker(failure_threshold=3) def test_circuit_breaker_registry_removed(self):
"""熔断器注册表已移除"""
# 初始状态:关闭 pass
assert breaker.state == CircuitState.CLOSED
assert breaker.allow_request()
# 连续失败
breaker.record_failure()
breaker.record_failure()
assert breaker.allow_request() # 还没到阈值
breaker.record_failure() # 第3次触发熔断
assert breaker.state == CircuitState.OPEN
assert not breaker.allow_request()
# 重置
breaker.reset()
assert breaker.state == CircuitState.CLOSED
def test_circuit_breaker_registry(self):
"""熔断器注册表测试"""
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry
registry = CircuitBreakerRegistry(failure_threshold=2)
b1 = registry.get("plugin_a")
b2 = registry.get("plugin_b")
assert b1 is not b2
assert registry.get("plugin_a") is b1 # 同一个
# ─── SDK 测试 ───────────────────────────────────────────── # ─── SDK 测试 ─────────────────────────────────────────────
@@ -355,7 +312,7 @@ class TestE2E:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handshake(self): async def test_handshake(self):
"""Host-Runner 握手流程测试""" """Host-Runner 握手流程测试"""
from src.plugin_runtime.protocol.codec import create_codec from src.plugin_runtime.protocol.codec import MsgPackCodec
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
@@ -365,7 +322,7 @@ class TestE2E:
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock") socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
session_token = secrets.token_hex(16) session_token = secrets.token_hex(16)
codec = create_codec() codec = MsgPackCodec()
handshake_done = asyncio.Event() handshake_done = asyncio.Event()
server_result = {} server_result = {}
@@ -425,3 +382,671 @@ class TestE2E:
await conn.close() await conn.close()
await server.stop() await server.stop()
# ─── Manifest 校验测试 ─────────────────────────────────────
class TestManifestValidator:
"""Manifest 校验器测试"""
def test_valid_manifest(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
validator = ManifestValidator()
manifest = {
"manifest_version": 1,
"name": "test_plugin",
"version": "1.0.0",
"description": "测试插件",
"author": "test",
}
assert validator.validate(manifest) is True
assert len(validator.errors) == 0
def test_missing_required_fields(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
validator = ManifestValidator()
manifest = {"manifest_version": 1}
assert validator.validate(manifest) is False
assert len(validator.errors) >= 4 # name, version, description, author
def test_unsupported_manifest_version(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
validator = ManifestValidator()
manifest = {
"manifest_version": 999,
"name": "test",
"version": "1.0",
"description": "d",
"author": "a",
}
assert validator.validate(manifest) is False
assert any("manifest_version" in e for e in validator.errors)
def test_host_version_compatibility(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
validator = ManifestValidator(host_version="0.8.5")
manifest = {
"name": "test",
"version": "1.0",
"description": "d",
"author": "a",
"host_application": {"min_version": "0.9.0"},
}
assert validator.validate(manifest) is False
assert any("Host 版本不兼容" in e for e in validator.errors)
def test_recommended_fields_warning(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
validator = ManifestValidator()
manifest = {
"name": "test",
"version": "1.0",
"description": "d",
"author": "a",
}
validator.validate(manifest)
assert len(validator.warnings) >= 3 # license, keywords, categories
class TestVersionComparator:
"""版本号比较器测试"""
def test_normalize(self):
from src.plugin_runtime.runner.manifest_validator import VersionComparator
assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0"
assert VersionComparator.normalize_version("1.2") == "1.2.0"
assert VersionComparator.normalize_version("") == "0.0.0"
def test_compare(self):
from src.plugin_runtime.runner.manifest_validator import VersionComparator
assert VersionComparator.compare("0.8.0", "0.8.0") == 0
assert VersionComparator.compare("0.8.0", "0.9.0") == -1
assert VersionComparator.compare("1.0.0", "0.9.0") == 1
def test_is_in_range(self):
from src.plugin_runtime.runner.manifest_validator import VersionComparator
ok, _ = VersionComparator.is_in_range("0.8.5", "0.8.0", "0.9.0")
assert ok
ok, _ = VersionComparator.is_in_range("0.7.0", "0.8.0", "0.9.0")
assert not ok
ok, _ = VersionComparator.is_in_range("1.0.0", "0.8.0", "0.9.0")
assert not ok
# ─── 依赖解析测试 ──────────────────────────────────────────
class TestDependencyResolution:
"""插件依赖解析测试"""
def test_topological_sort(self):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
loader = PluginLoader()
candidates = {
"core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"),
"auth": ("dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py"),
"api": ("dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, "plugin.py"),
}
order, failed = loader._resolve_dependencies(candidates)
assert len(failed) == 0
assert order.index("core") < order.index("auth")
assert order.index("auth") < order.index("api")
def test_missing_dependency(self):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
loader = PluginLoader()
candidates = {
"plugin_a": ("dir_a", {"name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"]}, "plugin.py"),
}
order, failed = loader._resolve_dependencies(candidates)
assert "plugin_a" in failed
assert "缺少依赖" in failed["plugin_a"]
def test_circular_dependency(self):
from src.plugin_runtime.runner.plugin_loader import PluginLoader
loader = PluginLoader()
candidates = {
"a": ("dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py"),
"b": ("dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py"),
}
order, failed = loader._resolve_dependencies(candidates)
assert len(failed) >= 1 # 至少一个循环插件被标记
# ─── Host-side ComponentRegistry 测试 ──────────────────────
class TestComponentRegistry:
"""Host-side 组件注册表测试"""
def test_register_and_query(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry()
reg.register_component("greet", "action", "plugin_a", {
"description": "打招呼",
"activation_type": "keyword",
"activation_keywords": ["hi"],
})
reg.register_component("help", "command", "plugin_a", {
"command_pattern": r"^/help",
})
reg.register_component("search", "tool", "plugin_b", {
"description": "搜索",
})
stats = reg.get_stats()
assert stats["total"] == 3
assert stats["action"] == 1
assert stats["command"] == 1
assert stats["tool"] == 1
def test_query_by_type(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry()
reg.register_component("a1", "action", "p1", {})
reg.register_component("a2", "action", "p2", {})
actions = reg.get_components_by_type("action")
assert len(actions) == 2
def test_find_command_by_text(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry()
reg.register_component("help", "command", "p1", {
"command_pattern": r"^/help",
})
reg.register_component("echo", "command", "p1", {
"command_pattern": r"^/echo\s",
})
match = reg.find_command_by_text("/help me")
assert match is not None
assert match.name == "help"
match = reg.find_command_by_text("/echo hello")
assert match is not None
assert match.name == "echo"
match = reg.find_command_by_text("no match")
assert match is None
def test_enable_disable(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry()
reg.register_component("a1", "action", "p1", {})
reg.set_component_enabled("p1.a1", False)
actions = reg.get_components_by_type("action", enabled_only=True)
assert len(actions) == 0
actions = reg.get_components_by_type("action", enabled_only=False)
assert len(actions) == 1
def test_remove_by_plugin(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry()
reg.register_component("a1", "action", "p1", {})
reg.register_component("c1", "command", "p1", {})
reg.register_component("a2", "action", "p2", {})
removed = reg.remove_components_by_plugin("p1")
assert removed == 2
assert reg.get_stats()["total"] == 1
def test_event_handlers_sorted_by_weight(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry()
reg.register_component("h_low", "event_handler", "p1", {
"event_type": "on_message", "weight": 10,
})
reg.register_component("h_high", "event_handler", "p2", {
"event_type": "on_message", "weight": 100,
})
handlers = reg.get_event_handlers("on_message")
assert handlers[0].name == "h_high"
assert handlers[1].name == "h_low"
def test_tools_for_llm(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
reg = ComponentRegistry()
reg.register_component("search", "tool", "p1", {
"description": "搜索工具",
"parameters_raw": {"query": {"type": "string"}},
})
tools = reg.get_tools_for_llm()
assert len(tools) == 1
assert tools[0]["name"] == "p1.search"
assert tools[0]["parameters"]["query"]["type"] == "string"
# ─── EventDispatcher 测试 ─────────────────────────────────
class TestEventDispatcher:
"""Host-side 事件分发器测试"""
@pytest.mark.asyncio
async def test_dispatch_non_blocking(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
reg = ComponentRegistry()
reg.register_component("h1", "event_handler", "p1", {
"event_type": "on_start",
"weight": 0,
"intercept_message": False,
})
dispatcher = EventDispatcher(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append((plugin_id, comp_name))
return {"success": True, "continue_processing": True}
should_continue, modified = await dispatcher.dispatch_event(
"on_start", mock_invoke
)
assert should_continue
# 非阻塞分发是异步的,等一下让 task 完成
await asyncio.sleep(0.1)
assert len(call_log) == 1
assert call_log[0] == ("p1", "h1")
@pytest.mark.asyncio
async def test_dispatch_intercepting(self):
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
reg = ComponentRegistry()
reg.register_component("filter", "event_handler", "p1", {
"event_type": "on_message_pre_process",
"weight": 100,
"intercept_message": True,
})
dispatcher = EventDispatcher(reg)
async def mock_invoke(plugin_id, comp_name, args):
return {
"success": True,
"continue_processing": False,
"modified_message": {"plain_text": "filtered"},
}
should_continue, modified = await dispatcher.dispatch_event(
"on_message_pre_process", mock_invoke, message={"plain_text": "hello"}
)
assert not should_continue
assert modified is not None
assert modified["plain_text"] == "filtered"
# ─── MaiMessages 测试 ─────────────────────────────────────
class TestMaiMessages:
"""统一消息模型测试"""
def test_create_and_serialize(self):
from maibot_sdk.messages import MaiMessages, MessageSegment
msg = MaiMessages(
message_segments=[MessageSegment(type="text", data={"text": "hello"})],
plain_text="hello",
stream_id="stream_1",
)
d = msg.to_rpc_dict()
assert d["plain_text"] == "hello"
assert len(d["message_segments"]) == 1
msg2 = MaiMessages.from_rpc_dict(d)
assert msg2.plain_text == "hello"
def test_deepcopy(self):
from maibot_sdk.messages import MaiMessages
msg = MaiMessages(plain_text="original")
msg2 = msg.deepcopy()
msg2.plain_text = "modified"
assert msg.plain_text == "original"
def test_modify_flags(self):
from maibot_sdk.messages import MaiMessages
from maibot_sdk.types import ModifyFlag
msg = MaiMessages(plain_text="hello")
assert msg.can_modify(ModifyFlag.CAN_MODIFY_PROMPT)
msg.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False)
assert not msg.modify_prompt("new prompt")
assert msg.llm_prompt is None
assert msg.modify_response("new response")
assert msg.llm_response_content == "new response"
# ─── WorkflowExecutor 测试 ────────────────────────────────
class TestWorkflowExecutor:
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
@pytest.mark.asyncio
async def test_empty_pipeline_completes(self):
"""无任何 workflow_step 注册时pipeline 全阶段跳过,状态 completed"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
return {"hook_result": "continue"}
result, final_msg, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "test"},
)
assert result.status == "completed"
assert result.return_message == "workflow completed"
assert len(ctx.timings) == 6 # 6 stages
@pytest.mark.asyncio
async def test_blocking_hook_modifies_message(self):
"""blocking hook 可以修改消息"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component("upper", "workflow_step", "p1", {
"stage": "pre_process",
"priority": 10,
"blocking": True,
})
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
msg = args.get("message", {})
return {
"hook_result": "continue",
"modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
}
result, final_msg, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "hello"},
)
assert result.status == "completed"
assert final_msg["plain_text"] == "HELLO"
assert len(ctx.modification_log) == 1
assert ctx.modification_log[0].stage == "pre_process"
@pytest.mark.asyncio
async def test_abort_stops_pipeline(self):
"""HookResult.ABORT 立即终止 pipeline"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component("blocker", "workflow_step", "p1", {
"stage": "pre_process",
"priority": 10,
"blocking": True,
})
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
return {"hook_result": "abort"}
result, _, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "test"},
)
assert result.status == "aborted"
assert result.stopped_at == "pre_process"
@pytest.mark.asyncio
async def test_skip_stage(self):
"""HookResult.SKIP_STAGE 跳过当前阶段剩余 hook"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
# high-priority hook 返回 skip_stage
reg.register_component("skipper", "workflow_step", "p1", {
"stage": "ingress",
"priority": 100,
"blocking": True,
})
# low-priority hook 不应被执行
reg.register_component("checker", "workflow_step", "p2", {
"stage": "ingress",
"priority": 1,
"blocking": True,
})
executor = WorkflowExecutor(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
if comp_name == "skipper":
return {"hook_result": "skip_stage"}
return {"hook_result": "continue"}
result, _, _ = await executor.execute(
mock_invoke, message={"plain_text": "test"}
)
assert result.status == "completed"
# 只有 skipper 被调用checker 被跳过
assert call_log == ["skipper"]
@pytest.mark.asyncio
async def test_pre_filter(self):
"""filter 条件不匹配时跳过 hook"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component("only_dm", "workflow_step", "p1", {
"stage": "ingress",
"priority": 10,
"blocking": True,
"filter": {"chat_type": "direct"},
})
executor = WorkflowExecutor(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
return {"hook_result": "continue"}
# 不匹配 filter —— hook 不应被调用
await executor.execute(
mock_invoke, message={"plain_text": "hi", "chat_type": "group"}
)
assert call_log == []
# 匹配 filter —— hook 应被调用
await executor.execute(
mock_invoke, message={"plain_text": "hi", "chat_type": "direct"}
)
assert call_log == ["only_dm"]
@pytest.mark.asyncio
async def test_error_policy_skip(self):
"""error_policy=skip 时跳过失败的 hook 继续执行"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component("failer", "workflow_step", "p1", {
"stage": "ingress",
"priority": 100,
"blocking": True,
"error_policy": "skip",
})
reg.register_component("ok_step", "workflow_step", "p2", {
"stage": "ingress",
"priority": 1,
"blocking": True,
})
executor = WorkflowExecutor(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
if comp_name == "failer":
raise RuntimeError("boom")
return {"hook_result": "continue"}
result, _, ctx = await executor.execute(
mock_invoke, message={"plain_text": "test"}
)
assert result.status == "completed"
assert "failer" in call_log
assert "ok_step" in call_log
assert any("boom" in e for e in ctx.errors)
@pytest.mark.asyncio
async def test_error_policy_abort(self):
"""error_policy=abort默认时 pipeline 失败"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component("failer", "workflow_step", "p1", {
"stage": "ingress",
"priority": 10,
"blocking": True,
# error_policy defaults to "abort"
})
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
raise RuntimeError("fatal")
result, _, ctx = await executor.execute(
mock_invoke, message={"plain_text": "test"}
)
assert result.status == "failed"
assert result.stopped_at == "ingress"
@pytest.mark.asyncio
async def test_nonblocking_hooks_concurrent(self):
"""non-blocking hook 并发执行,不修改消息"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
for i in range(3):
reg.register_component(f"nb_{i}", "workflow_step", f"p{i}", {
"stage": "post_process",
"priority": 0,
"blocking": False,
})
executor = WorkflowExecutor(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
result, final_msg, _ = await executor.execute(
mock_invoke, message={"plain_text": "original"}
)
# non-blocking 的 modified_message 被忽略
assert final_msg["plain_text"] == "original"
# 给异步 task 时间完成
await asyncio.sleep(0.1)
assert result.status == "completed"
@pytest.mark.asyncio
async def test_command_routing(self):
"""PLAN 阶段内置命令路由"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component("help", "command", "p1", {
"command_pattern": r"^/help",
})
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
if comp_name == "help":
return {"output": "帮助信息"}
return {"hook_result": "continue"}
result, _, ctx = await executor.execute(
mock_invoke, message={"plain_text": "/help topic"}
)
assert result.status == "completed"
assert ctx.matched_command == "p1.help"
cmd_result = ctx.get_stage_output("plan", "command_result")
assert cmd_result is not None
assert cmd_result["output"] == "帮助信息"
@pytest.mark.asyncio
async def test_stage_outputs(self):
"""stage_outputs 数据在阶段间传递"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
# ingress 阶段写入数据
reg.register_component("writer", "workflow_step", "p1", {
"stage": "ingress",
"priority": 10,
"blocking": True,
})
# pre_process 阶段读取数据
reg.register_component("reader", "workflow_step", "p2", {
"stage": "pre_process",
"priority": 10,
"blocking": True,
})
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
if comp_name == "writer":
return {
"hook_result": "continue",
"stage_output": {"parsed_intent": "greeting"},
}
if comp_name == "reader":
# 验证 stage_outputs 被传递过来
outputs = args.get("stage_outputs", {})
ingress_data = outputs.get("ingress", {})
assert ingress_data.get("parsed_intent") == "greeting"
return {"hook_result": "continue"}
return {"hook_result": "continue"}
result, _, ctx = await executor.execute(
mock_invoke, message={"plain_text": "hi"}
)
assert result.status == "completed"
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"

View File

@@ -1,2 +1 @@
# MaiBot Plugin Runtime - 插件隔离运行时基础设施
# 本模块实现 Host-Runner 进程分离架构,提供 IPC 通信、策略引擎与生命周期管理

View File

@@ -1 +1 @@
# Host 端 - Supervisor、RPC Server、策略引擎、路由

View File

@@ -73,15 +73,7 @@ class CapabilityService:
reason, reason,
) )
# 2. 限流校验 # 2. 查找实现
allowed, reason = self._policy.check_rate_limit(plugin_id)
if not allowed:
return envelope.make_error_response(
ErrorCode.E_BACKPRESSURE.value,
reason,
)
# 3. 查找实现
impl = self._implementations.get(capability) impl = self._implementations.get(capability)
if impl is None: if impl is None:
return envelope.make_error_response( return envelope.make_error_response(
@@ -89,7 +81,7 @@ class CapabilityService:
f"未注册的能力: {capability}", f"未注册的能力: {capability}",
) )
# 4. 执行 # 3. 执行
try: try:
result = await impl(plugin_id, capability, req.args) result = await impl(plugin_id, capability, req.args)
resp_payload = CapabilityResponsePayload(success=True, result=result) resp_payload = CapabilityResponsePayload(success=True, result=result)

View File

@@ -1,105 +0,0 @@
"""熔断器
为每个插件提供熔断保护,连续失败超过阈值后临时禁用。
支持指数退避恢复。
"""
from enum import Enum
import time
class CircuitState(str, Enum):
CLOSED = "closed" # 正常工作
OPEN = "open" # 熔断(拒绝所有调用)
HALF_OPEN = "half_open" # 探测恢复
class CircuitBreaker:
"""单个插件的熔断器"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout_sec: float = 30.0,
max_recovery_timeout_sec: float = 300.0,
):
self.failure_threshold = failure_threshold
self.base_recovery_timeout = recovery_timeout_sec
self.max_recovery_timeout = max_recovery_timeout_sec
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time = 0.0
self._consecutive_opens = 0 # 用于指数退避
@property
def state(self) -> CircuitState:
if self._state == CircuitState.OPEN:
# 检查是否可以进入半开状态
elapsed = time.monotonic() - self._last_failure_time
recovery_timeout = min(
self.base_recovery_timeout * (2 ** self._consecutive_opens),
self.max_recovery_timeout,
)
if elapsed >= recovery_timeout:
self._state = CircuitState.HALF_OPEN
return self._state
def allow_request(self) -> bool:
"""是否允许通过请求"""
state = self.state
if state == CircuitState.CLOSED:
return True
if state == CircuitState.HALF_OPEN:
return True # 允许一次试探
return False # OPEN 状态拒绝
def record_success(self) -> None:
"""记录一次成功调用"""
if self._state == CircuitState.HALF_OPEN:
# 半开状态成功 -> 关闭熔断
self._state = CircuitState.CLOSED
self._failure_count = 0
self._consecutive_opens = 0
elif self._state == CircuitState.CLOSED:
self._failure_count = 0
def record_failure(self) -> None:
"""记录一次失败调用"""
self._failure_count += 1
self._last_failure_time = time.monotonic()
if self._state == CircuitState.HALF_OPEN:
# 半开状态失败 -> 重新开启熔断
self._state = CircuitState.OPEN
self._consecutive_opens += 1
elif self._failure_count >= self.failure_threshold:
self._state = CircuitState.OPEN
self._consecutive_opens += 1
def reset(self) -> None:
"""重置熔断器"""
self._state = CircuitState.CLOSED
self._failure_count = 0
self._consecutive_opens = 0
class CircuitBreakerRegistry:
"""熔断器注册表,为每个插件维护独立的熔断器"""
def __init__(self, **default_kwargs):
self._breakers: dict[str, CircuitBreaker] = {}
self._default_kwargs = default_kwargs
def get(self, plugin_id: str) -> CircuitBreaker:
if plugin_id not in self._breakers:
self._breakers[plugin_id] = CircuitBreaker(**self._default_kwargs)
return self._breakers[plugin_id]
def remove(self, plugin_id: str) -> None:
self._breakers.pop(plugin_id, None)
def reset_all(self) -> None:
for breaker in self._breakers.values():
breaker.reset()

View File

@@ -0,0 +1,235 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件action / command / tool / event_handler / workflow_step
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
- 多维度查询(按名称、类型、插件)
- 注册统计
"""
from typing import Any
import logging
import re
logger = logging.getLogger("plugin_runtime.host.component_registry")
class RegisteredComponent:
"""已注册的组件条目"""
__slots__ = (
"name", "full_name", "component_type", "plugin_id",
"metadata", "enabled", "_compiled_pattern",
)
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: dict[str, Any],
):
self.name = name
self.full_name = f"{plugin_id}.{name}"
self.component_type = component_type
self.plugin_id = plugin_id
self.metadata = metadata
self.enabled = metadata.get("enabled", True)
# 预编译命令正则(仅 command 类型)
self._compiled_pattern: re.Pattern | None = None
if component_type == "command":
pattern = metadata.get("command_pattern", "")
if pattern:
try:
self._compiled_pattern = re.compile(pattern)
except re.error as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
class ComponentRegistry:
"""Host-side 组件注册表
由 Supervisor 在收到 plugin.register_components 时调用。
供业务层查询可用组件、匹配命令、调度 action/event 等。
"""
def __init__(self):
# 全量索引
self._components: dict[str, RegisteredComponent] = {} # full_name -> comp
# 按类型索引
self._by_type: dict[str, dict[str, RegisteredComponent]] = {
"action": {},
"command": {},
"tool": {},
"event_handler": {},
"workflow_step": {},
}
# 按插件索引
self._by_plugin: dict[str, list[RegisteredComponent]] = {}
# ──── 注册 / 注销 ─────────────────────────────────────────
def register_component(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: dict[str, Any],
) -> bool:
"""注册单个组件。"""
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
self._components[comp.full_name] = comp
if component_type not in self._by_type:
self._by_type[component_type] = {}
self._by_type[component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
return True
def register_plugin_components(
self,
plugin_id: str,
components: list[dict[str, Any]],
) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。"""
count = 0
for comp_data in components:
ok = self.register_component(
name=comp_data.get("name", ""),
component_type=comp_data.get("component_type", ""),
plugin_id=plugin_id,
metadata=comp_data.get("metadata", {}),
)
if ok:
count += 1
return count
def remove_components_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的所有组件,返回移除数量。"""
comps = self._by_plugin.pop(plugin_id, [])
for comp in comps:
self._components.pop(comp.full_name, None)
type_dict = self._by_type.get(comp.component_type)
if type_dict:
type_dict.pop(comp.full_name, None)
return len(comps)
# ──── 启用 / 禁用 ─────────────────────────────────────────
def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
"""启用或禁用指定组件。"""
comp = self._components.get(full_name)
if comp is None:
return False
comp.enabled = enabled
return True
def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
"""批量启用或禁用某插件的所有组件。"""
comps = self._by_plugin.get(plugin_id, [])
for comp in comps:
comp.enabled = enabled
return len(comps)
# ──── 查询方法 ─────────────────────────────────────────────
def get_component(self, full_name: str) -> RegisteredComponent | None:
"""按全名查询。"""
return self._components.get(full_name)
def get_components_by_type(
self, component_type: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]:
"""按类型查询。"""
type_dict = self._by_type.get(component_type, {})
if enabled_only:
return [c for c in type_dict.values() if c.enabled]
return list(type_dict.values())
def get_components_by_plugin(
self, plugin_id: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]:
"""按插件查询。"""
comps = self._by_plugin.get(plugin_id, [])
if enabled_only:
return [c for c in comps if c.enabled]
return list(comps)
def find_command_by_text(self, text: str) -> RegisteredComponent | None:
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。"""
for comp in self._by_type.get("command", {}).values():
if not comp.enabled:
continue
if comp._compiled_pattern and comp._compiled_pattern.search(text):
return comp
# 别名匹配
aliases = comp.metadata.get("aliases", [])
for alias in aliases:
if text.startswith(alias):
return comp
return None
def get_event_handlers(
self, event_type: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]:
"""获取特定事件类型的所有 event_handler按 weight 降序排列。"""
handlers = []
for comp in self._by_type.get("event_handler", {}).values():
if enabled_only and not comp.enabled:
continue
if comp.metadata.get("event_type") == event_type:
handlers.append(comp)
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
return handlers
def get_workflow_steps(
self, stage: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
steps = []
for comp in self._by_type.get("workflow_step", {}).values():
if enabled_only and not comp.enabled:
continue
if comp.metadata.get("stage") == stage:
steps.append(comp)
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
return steps
def get_tools_for_llm(self, *, enabled_only: bool = True) -> list[dict[str, Any]]:
"""获取可供 LLM 使用的工具列表openai function-calling 格式预览)。"""
result = []
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
tool_def: dict[str, Any] = {
"name": comp.full_name,
"description": comp.metadata.get("description", ""),
}
# 从结构化参数或原始参数构建 parameters
params = comp.metadata.get("parameters", [])
params_raw = comp.metadata.get("parameters_raw", {})
if params:
tool_def["parameters"] = params
elif params_raw:
tool_def["parameters"] = params_raw
result.append(tool_def)
return result
# ──── 统计 ─────────────────────────────────────────────────
def get_stats(self) -> dict[str, int]:
"""获取注册统计。"""
stats: dict[str, int] = {"total": len(self._components)}
for comp_type, type_dict in self._by_type.items():
stats[comp_type] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats

View File

@@ -0,0 +1,146 @@
"""Host-side EventDispatcher
负责:
1. 按事件类型查询已注册的 event_handler通过 ComponentRegistry
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
3. 支持阻塞intercept_message和非阻塞分发
4. 事件结果历史记录
"""
from typing import Any, Optional
import asyncio
import logging
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
logger = logging.getLogger("plugin_runtime.host.event_dispatcher")
class EventResult:
"""单个 EventHandler 的执行结果"""
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
def __init__(
self,
handler_name: str,
success: bool = True,
continue_processing: bool = True,
modified_message: dict[str, Any] | None = None,
custom_result: Any = None,
):
self.handler_name = handler_name
self.success = success
self.continue_processing = continue_processing
self.modified_message = modified_message
self.custom_result = custom_result
class EventDispatcher:
"""Host-side 事件分发器
由业务层调用 dispatch_event()
内部通过 ComponentRegistry 查询 handler
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
def __init__(self, registry: ComponentRegistry):
self._registry = registry
self._result_history: dict[str, list[EventResult]] = {}
self._history_enabled: set[str] = set()
def enable_history(self, event_type: str) -> None:
self._history_enabled.add(event_type)
self._result_history.setdefault(event_type, [])
def get_history(self, event_type: str) -> list[EventResult]:
return self._result_history.get(event_type, [])
def clear_history(self, event_type: str) -> None:
if event_type in self._result_history:
self._result_history[event_type] = []
async def dispatch_event(
self,
event_type: str,
invoke_fn, # async (plugin_id, component_name, args) -> dict — Supervisor.invoke_plugin wrapper
message: dict[str, Any] | None = None,
extra_args: dict[str, Any] | None = None,
) -> tuple[bool, Optional[dict[str, Any]]]:
"""分发事件到所有对应 handler。
Args:
event_type: 事件类型字符串
invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
message: MaiMessages 序列化后的 dict可选
extra_args: 额外参数
Returns:
(should_continue, modified_message_dict)
"""
handlers = self._registry.get_event_handlers(event_type)
if not handlers:
return True, None
should_continue = True
modified_message: dict[str, Any] | None = None
fire_and_forget_tasks: list[asyncio.Task] = []
for handler in handlers:
intercept = handler.metadata.get("intercept_message", False)
args = {
"event_type": event_type,
"message": modified_message or message,
**(extra_args or {}),
}
if intercept:
# 阻塞执行
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
if result and not result.continue_processing:
should_continue = False
if result and result.modified_message:
modified_message = result.modified_message
else:
# 非阻塞
task = asyncio.create_task(
self._invoke_handler(invoke_fn, handler, args, event_type)
)
fire_and_forget_tasks.append(task)
# 不等待 fire-and-forget 任务(但不丢弃引用以防 GC
if fire_and_forget_tasks:
for t in fire_and_forget_tasks:
t.add_done_callback(lambda _t: None)
return should_continue, modified_message
async def _invoke_handler(
self,
invoke_fn,
handler: RegisteredComponent,
args: dict[str, Any],
event_type: str,
) -> EventResult | None:
"""调用单个 handler 并收集结果。"""
try:
resp = await invoke_fn(handler.plugin_id, handler.name, args)
result = EventResult(
handler_name=handler.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_message=resp.get("modified_message"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
result = EventResult(
handler_name=handler.full_name,
success=False,
continue_processing=True,
)
if event_type in self._history_enabled:
self._result_history.setdefault(event_type, []).append(result)
return result

View File

@@ -1,42 +1,27 @@
"""策略引擎 """策略引擎
负责能力授权校验、限流、配额管理 负责能力授权校验。
每个插件在 manifest 中声明能力需求Host 启动时签发能力令牌。 每个插件在 manifest 中声明能力需求Host 启动时签发能力令牌。
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
import time
@dataclass @dataclass
class CapabilityToken: class CapabilityToken:
"""能力令牌 """能力令牌"""
描述某个插件在当前会话中被授予的能力和资源限制。
"""
plugin_id: str plugin_id: str
generation: int generation: int
capabilities: set[str] = field(default_factory=set) capabilities: set[str] = field(default_factory=set)
qps_limit: int = 20
burst_limit: int = 50
daily_token_limit: int = 200000
max_payload_kb: int = 256
# 运行时统计
_call_count: int = field(default=0, init=False, repr=False)
_window_start: float = field(default_factory=time.monotonic, init=False, repr=False)
_window_calls: int = field(default=0, init=False, repr=False)
class PolicyEngine: class PolicyEngine:
"""策略引擎 """策略引擎
管理所有插件的能力令牌,提供授权校验与限流决策 管理所有插件的能力令牌,提供授权校验。
""" """
def __init__(self): def __init__(self):
# plugin_id -> CapabilityToken
self._tokens: dict[str, CapabilityToken] = {} self._tokens: dict[str, CapabilityToken] = {}
def register_plugin( def register_plugin(
@@ -44,18 +29,12 @@ class PolicyEngine:
plugin_id: str, plugin_id: str,
generation: int, generation: int,
capabilities: list[str], capabilities: list[str],
limits: dict | None = None,
) -> CapabilityToken: ) -> CapabilityToken:
"""为插件签发能力令牌""" """为插件签发能力令牌"""
limits = limits or {}
token = CapabilityToken( token = CapabilityToken(
plugin_id=plugin_id, plugin_id=plugin_id,
generation=generation, generation=generation,
capabilities=set(capabilities), capabilities=set(capabilities),
qps_limit=limits.get("qps", 20),
burst_limit=limits.get("burst", 50),
daily_token_limit=limits.get("daily_tokens", 200000),
max_payload_kb=limits.get("max_payload_kb", 256),
) )
self._tokens[plugin_id] = token self._tokens[plugin_id] = token
return token return token
@@ -79,43 +58,6 @@ class PolicyEngine:
return True, "" return True, ""
def check_rate_limit(self, plugin_id: str) -> tuple[bool, str]:
"""检查插件是否超过调用频率限制(滑动窗口)
Returns:
(allowed, reason)
"""
token = self._tokens.get(plugin_id)
if token is None:
return False, f"插件 {plugin_id} 未注册"
now = time.monotonic()
elapsed = now - token._window_start
# 每秒重置窗口
if elapsed >= 1.0:
token._window_start = now
token._window_calls = 0
token._window_calls += 1
if token._window_calls > token.burst_limit:
return False, f"插件 {plugin_id} 超过突发限制 ({token.burst_limit}/s)"
return True, ""
def check_payload_size(self, plugin_id: str, payload_size_bytes: int) -> tuple[bool, str]:
"""检查 payload 大小是否在限制内"""
token = self._tokens.get(plugin_id)
if token is None:
return False, f"插件 {plugin_id} 未注册"
max_bytes = token.max_payload_kb * 1024
if payload_size_bytes > max_bytes:
return False, f"payload 大小 {payload_size_bytes} 超过限制 {max_bytes}"
return True, ""
def get_token(self, plugin_id: str) -> CapabilityToken | None: def get_token(self, plugin_id: str) -> CapabilityToken | None:
"""获取插件的能力令牌""" """获取插件的能力令牌"""
return self._tokens.get(plugin_id) return self._tokens.get(plugin_id)

View File

@@ -13,7 +13,7 @@ import asyncio
import logging import logging
import secrets import secrets
from src.plugin_runtime.protocol.codec import Codec, create_codec from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
from src.plugin_runtime.protocol.envelope import ( from src.plugin_runtime.protocol.envelope import (
PROTOCOL_VERSION, PROTOCOL_VERSION,
MIN_SDK_VERSION, MIN_SDK_VERSION,
@@ -48,7 +48,7 @@ class RPCServer:
): ):
self._transport = transport self._transport = transport
self._session_token = session_token or secrets.token_hex(32) self._session_token = session_token or secrets.token_hex(32)
self._codec = codec or create_codec() self._codec = codec or MsgPackCodec()
self._send_queue_size = send_queue_size self._send_queue_size = send_queue_size
self._id_gen = RequestIdGenerator() self._id_gen = RequestIdGenerator()

View File

@@ -2,10 +2,9 @@
负责: 负责:
1. 拉起 Runner 子进程 1. 拉起 Runner 子进程
2. 健康检查 2. 健康检查 + 崩溃自动重启
3. 熔断与恢复 3. 代码热重载generation 切换)
4. 代码热重载generation 切换) 4. 优雅关停
5. 优雅关停
""" """
from typing import Any from typing import Any
@@ -16,9 +15,11 @@ import os
import sys import sys
from src.plugin_runtime.host.capability_service import CapabilityService from src.plugin_runtime.host.capability_service import CapabilityService
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
from src.plugin_runtime.host.policy_engine import PolicyEngine from src.plugin_runtime.host.policy_engine import PolicyEngine
from src.plugin_runtime.host.rpc_server import RPCServer from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult
from src.plugin_runtime.protocol.envelope import ( from src.plugin_runtime.protocol.envelope import (
Envelope, Envelope,
HealthPayload, HealthPayload,
@@ -42,7 +43,6 @@ class PluginSupervisor:
plugin_dirs: list[str] | None = None, plugin_dirs: list[str] | None = None,
socket_path: str | None = None, socket_path: str | None = None,
health_check_interval_sec: float = 30.0, health_check_interval_sec: float = 30.0,
use_json_codec: bool = False,
): ):
self._plugin_dirs = plugin_dirs or [] self._plugin_dirs = plugin_dirs or []
self._health_interval = health_check_interval_sec self._health_interval = health_check_interval_sec
@@ -50,12 +50,14 @@ class PluginSupervisor:
# 基础设施 # 基础设施
self._transport = create_transport_server(socket_path=socket_path) self._transport = create_transport_server(socket_path=socket_path)
self._policy = PolicyEngine() self._policy = PolicyEngine()
self._breakers = CircuitBreakerRegistry()
self._capability_service = CapabilityService(self._policy) self._capability_service = CapabilityService(self._policy)
self._component_registry = ComponentRegistry()
self._event_dispatcher = EventDispatcher(self._component_registry)
self._workflow_executor = WorkflowExecutor(self._component_registry)
# 编解码 # 编解码
from src.plugin_runtime.protocol.codec import create_codec from src.plugin_runtime.protocol.codec import MsgPackCodec
codec = create_codec(use_json=use_json_codec) codec = MsgPackCodec()
self._rpc_server = RPCServer( self._rpc_server = RPCServer(
transport=self._transport, transport=self._transport,
@@ -65,6 +67,8 @@ class PluginSupervisor:
# Runner 子进程 # Runner 子进程
self._runner_process: asyncio.subprocess.Process | None = None self._runner_process: asyncio.subprocess.Process | None = None
self._runner_generation: int = 0 self._runner_generation: int = 0
self._max_restart_attempts: int = 3
self._restart_count: int = 0
# 已注册的插件组件信息 # 已注册的插件组件信息
self._registered_plugins: dict[str, RegisterComponentsPayload] = {} self._registered_plugins: dict[str, RegisterComponentsPayload] = {}
@@ -84,10 +88,72 @@ class PluginSupervisor:
def capability_service(self) -> CapabilityService: def capability_service(self) -> CapabilityService:
return self._capability_service return self._capability_service
@property
def component_registry(self) -> ComponentRegistry:
return self._component_registry
@property
def event_dispatcher(self) -> EventDispatcher:
return self._event_dispatcher
@property
def workflow_executor(self) -> WorkflowExecutor:
return self._workflow_executor
@property @property
def rpc_server(self) -> RPCServer: def rpc_server(self) -> RPCServer:
return self._rpc_server return self._rpc_server
async def dispatch_event(
self,
event_type: str,
message: dict[str, Any] | None = None,
extra_args: dict[str, Any] | None = None,
) -> tuple[bool, dict[str, Any] | None]:
"""分发事件到所有对应 handler 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
resp = await self.invoke_plugin(
method="plugin.emit_event",
plugin_id=plugin_id,
component_name=component_name,
args=args,
)
return resp.payload
return await self._event_dispatcher.dispatch_event(
event_type=event_type,
invoke_fn=_invoke,
message=message,
extra_args=extra_args,
)
async def execute_workflow(
self,
message: dict[str, Any] | None = None,
stream_id: str | None = None,
context: WorkflowContext | None = None,
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
"""执行 Workflow Pipeline 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
resp = await self.invoke_plugin(
method="plugin.invoke_workflow_step",
plugin_id=plugin_id,
component_name=component_name,
args=args,
)
payload = resp.payload
if payload.get("success"):
result = payload.get("result")
return result if isinstance(result, dict) else {}
raise RuntimeError(payload.get("result", "workflow step invoke failed"))
return await self._workflow_executor.execute(
invoke_fn=_invoke,
message=message,
stream_id=stream_id,
context=context,
)
async def start(self) -> None: async def start(self) -> None:
"""启动 Supervisor """启动 Supervisor
@@ -137,11 +203,6 @@ class PluginSupervisor:
由主进程业务逻辑调用,通过 RPC 转发给 Runner。 由主进程业务逻辑调用,通过 RPC 转发给 Runner。
""" """
# 熔断检查
breaker = self._breakers.get(plugin_id)
if not breaker.allow_request():
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, f"插件 {plugin_id} 已被熔断")
try: try:
response = await self._rpc_server.send_request( response = await self._rpc_server.send_request(
method=method, method=method,
@@ -152,10 +213,8 @@ class PluginSupervisor:
}, },
timeout_ms=timeout_ms, timeout_ms=timeout_ms,
) )
breaker.record_success()
return response return response
except RPCError: except RPCError:
breaker.record_failure()
raise raise
async def reload_plugins(self, reason: str = "manual") -> None: async def reload_plugins(self, reason: str = "manual") -> None:
@@ -232,12 +291,20 @@ class PluginSupervisor:
self._policy.register_plugin( self._policy.register_plugin(
plugin_id=reg.plugin_id, plugin_id=reg.plugin_id,
generation=envelope.generation, generation=envelope.generation,
capabilities=reg.capabilities_required, capabilities=reg.capabilities_required or [],
) )
# 在 ComponentRegistry 中注册组件
self._component_registry.register_plugin_components(
plugin_id=reg.plugin_id,
components=[c.model_dump() for c in reg.components],
)
stats = self._component_registry.get_stats()
logger.info( logger.info(
f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功," f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功,"
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}" f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
f"注册表总计: {stats}"
) )
return envelope.make_response(payload={"accepted": True}) return envelope.make_response(payload={"accepted": True})
@@ -294,10 +361,32 @@ class PluginSupervisor:
await self._runner_process.wait() await self._runner_process.wait()
async def _health_check_loop(self) -> None: async def _health_check_loop(self) -> None:
"""周期性健康检查""" """周期性健康检查 + 崩溃自动重启"""
while self._running: while self._running:
await asyncio.sleep(self._health_interval) await asyncio.sleep(self._health_interval)
# 检查 Runner 进程是否意外退出
if self._runner_process and self._runner_process.returncode is not None:
exit_code = self._runner_process.returncode
logger.warning(f"Runner 进程已退出 (exit_code={exit_code})")
if self._restart_count < self._max_restart_attempts:
self._restart_count += 1
logger.info(f"尝试重启 Runner ({self._restart_count}/{self._max_restart_attempts})")
# 清理旧的组件注册
for plugin_id in list(self._registered_plugins.keys()):
self._component_registry.remove_components_by_plugin(plugin_id)
self._policy.revoke_plugin(plugin_id)
self._registered_plugins.clear()
try:
await self._spawn_runner()
except Exception as e:
logger.error(f"Runner 重启失败: {e}", exc_info=True)
else:
logger.error(f"Runner 连续崩溃 {self._max_restart_attempts} 次,停止重启")
continue
if not self._rpc_server.is_connected: if not self._rpc_server.is_connected:
logger.warning("Runner 未连接,跳过健康检查") logger.warning("Runner 未连接,跳过健康检查")
continue continue
@@ -307,6 +396,9 @@ class PluginSupervisor:
health = HealthPayload.model_validate(resp.payload) health = HealthPayload.model_validate(resp.payload)
if not health.healthy: if not health.healthy:
logger.warning(f"Runner 健康检查异常: {health}") logger.warning(f"Runner 健康检查异常: {health}")
else:
# 健康检查成功,重置重启计数
self._restart_count = 0
except RPCError as e: except RPCError as e:
logger.error(f"健康检查失败: {e}") logger.error(f"健康检查失败: {e}")
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@@ -0,0 +1,397 @@
"""Host-side WorkflowExecutor
6 阶段线性流转INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS
每个阶段执行顺序:
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
2. 按 priority 降序排列
3. 串行执行 blocking hook可修改 message返回 HookResult
4. 并发执行 non-blocking hook只读
5. 检查是否有 SKIP_STAGE 或 ABORT
6. PLAN 阶段内置 Command 匹配路由
支持:
- HookResult: CONTINUE / SKIP_STAGE / ABORT
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
- stage_outputs: 阶段间带命名空间的数据传递
- modification_log: 消息修改审计
"""
from typing import Any, Callable, Awaitable, Optional
import asyncio
import logging
import time
import uuid
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
logger = logging.getLogger("plugin_runtime.host.workflow_executor")
# 阶段顺序
STAGE_SEQUENCE: list[str] = [
"ingress",
"pre_process",
"plan",
"tool_execute",
"post_process",
"egress",
]
# HookResult 常量(与 SDK HookResult enum 值对应)
HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort"
class ModificationRecord:
"""消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: list[str]):
self.stage = stage
self.hook_name = hook_name
self.timestamp = time.perf_counter()
self.fields_changed = fields_changed
class WorkflowContext:
"""Workflow 执行上下文"""
def __init__(self, trace_id: str | None = None, stream_id: str | None = None):
self.trace_id = trace_id or uuid.uuid4().hex
self.stream_id = stream_id
self.timings: dict[str, float] = {}
self.errors: list[str] = []
# 阶段间数据传递(按 stage 命名空间隔离)
self.stage_outputs: dict[str, dict[str, Any]] = {}
# 消息修改审计日志
self.modification_log: list[ModificationRecord] = []
# PLAN 阶段命令匹配结果
self.matched_command: str | None = None
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
self.stage_outputs.setdefault(stage, {})[key] = value
def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
return self.stage_outputs.get(stage, {}).get(key, default)
class WorkflowResult:
"""Workflow 执行结果"""
def __init__(
self,
status: str = "completed", # completed / aborted / failed
return_message: str = "",
stopped_at: str = "",
diagnostics: dict[str, Any] | None = None,
):
self.status = status
self.return_message = return_message
self.stopped_at = stopped_at
self.diagnostics = diagnostics or {}
# invoke_fn 签名
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]]
class WorkflowExecutor:
"""Host-side Workflow 执行器
实现 stage-based pipeline + per-stage hook chain with priority + early return。
"""
def __init__(self, registry: ComponentRegistry):
self._registry = registry
async def execute(
self,
invoke_fn: InvokeFn,
message: dict[str, Any] | None = None,
stream_id: str | None = None,
context: WorkflowContext | None = None,
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
"""执行 workflow pipeline。
Returns:
(result, final_message, context)
"""
ctx = context or WorkflowContext(stream_id=stream_id)
current_message = dict(message) if message else None
for stage in STAGE_SEQUENCE:
stage_start = time.perf_counter()
try:
# PLAN 阶段: 先做 Command 路由
if stage == "plan" and current_message:
cmd_result = await self._route_command(
invoke_fn, current_message, ctx
)
if cmd_result is not None:
# 命令匹配成功,跳过 PLAN 阶段的 hook直接存结果进 stage_outputs
ctx.set_stage_output("plan", "command_result", cmd_result)
ctx.timings[stage] = time.perf_counter() - stage_start
continue
# 获取该阶段所有 hook已按 priority 降序排列)
all_steps = self._registry.get_workflow_steps(stage)
if not all_steps:
ctx.timings[stage] = time.perf_counter() - stage_start
continue
# 1. Pre-filter
filtered_steps = self._pre_filter(all_steps, current_message)
# 2. 分离 blocking 和 non-blocking
blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
# 3. 串行执行 blocking hook
skip_stage = False
for step in blocking_steps:
hook_result, modified, step_error = await self._invoke_step(
invoke_fn, step, stage, ctx, current_message
)
if step_error:
error_policy = step.metadata.get("error_policy", "abort")
ctx.errors.append(f"{step.full_name}: {step_error}")
if error_policy == "abort":
ctx.timings[stage] = time.perf_counter() - stage_start
return (
WorkflowResult(
status="failed",
return_message=step_error,
stopped_at=stage,
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
),
current_message,
ctx,
)
elif error_policy == "skip":
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
continue
else: # log
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
continue
# 更新消息(仅 blocking hook 有权修改)
if modified:
changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys())
ctx.modification_log.append(
ModificationRecord(stage, step.full_name, changed_fields)
)
current_message = modified
if hook_result == HOOK_ABORT:
ctx.timings[stage] = time.perf_counter() - stage_start
return (
WorkflowResult(
status="aborted",
return_message=f"aborted by {step.full_name}",
stopped_at=stage,
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
),
current_message,
ctx,
)
if hook_result == HOOK_SKIP_STAGE:
skip_stage = True
break
# 4. 并发执行 non-blocking hook只读忽略返回值中的 modified_message
if nonblocking_steps and not skip_stage:
nb_tasks = [
self._invoke_step_fire_and_forget(
invoke_fn, step, stage, ctx, current_message
)
for step in nonblocking_steps
]
# 并发执行但不阻塞 pipeline
for task in [asyncio.create_task(t) for t in nb_tasks]:
task.add_done_callback(lambda _: None)
ctx.timings[stage] = time.perf_counter() - stage_start
except Exception as e:
ctx.timings[stage] = time.perf_counter() - stage_start
ctx.errors.append(f"{stage}: {e}")
logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
return (
WorkflowResult(
status="failed",
return_message=str(e),
stopped_at=stage,
diagnostics={"trace_id": ctx.trace_id},
),
current_message,
ctx,
)
return (
WorkflowResult(
status="completed",
return_message="workflow completed",
diagnostics={"trace_id": ctx.trace_id},
),
current_message,
ctx,
)
# ─── 内部方法 ──────────────────────────────────────────────
def _pre_filter(
self,
steps: list[RegisteredComponent],
message: dict[str, Any] | None,
) -> list[RegisteredComponent]:
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
if not message:
return steps
result = []
for step in steps:
filter_cond = step.metadata.get("filter", {})
if not filter_cond:
result.append(step)
continue
if self._match_filter(filter_cond, message):
result.append(step)
return result
@staticmethod
def _match_filter(filter_cond: dict[str, Any], message: dict[str, Any]) -> bool:
"""简单 key-value 匹配过滤。
filter 中的每个 key 必须在 message 中存在且值相等,
全部匹配才通过。
"""
for key, expected in filter_cond.items():
actual = message.get(key)
if isinstance(expected, list):
if actual not in expected:
return False
elif actual != expected:
return False
return True
async def _invoke_step(
self,
invoke_fn: InvokeFn,
step: RegisteredComponent,
stage: str,
ctx: WorkflowContext,
message: dict[str, Any] | None,
) -> tuple[str, dict[str, Any] | None, str | None]:
"""调用单个 blocking hook。
Returns:
(hook_result, modified_message, error_string_or_None)
"""
timeout_ms = step.metadata.get("timeout_ms", 0)
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None
step_key = f"{stage}:{step.full_name}"
step_start = time.perf_counter()
try:
coro = invoke_fn(step.plugin_id, step.name, {
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
})
resp = await asyncio.wait_for(coro, timeout=timeout_sec) if timeout_sec else await coro
ctx.timings[step_key] = time.perf_counter() - step_start
hook_result = resp.get("hook_result", HOOK_CONTINUE)
modified_message = resp.get("modified_message")
# 存 stage output如果 hook 提供了)
stage_out = resp.get("stage_output")
if isinstance(stage_out, dict):
for k, v in stage_out.items():
ctx.set_stage_output(stage, k, v)
return hook_result, modified_message, None
except asyncio.TimeoutError:
ctx.timings[step_key] = time.perf_counter() - step_start
return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
except Exception as e:
ctx.timings[step_key] = time.perf_counter() - step_start
return HOOK_CONTINUE, None, str(e)
async def _invoke_step_fire_and_forget(
self,
invoke_fn: InvokeFn,
step: RegisteredComponent,
stage: str,
ctx: WorkflowContext,
message: dict[str, Any] | None,
) -> None:
"""Non-blocking hook 调用,只读,忽略结果。"""
timeout_ms = step.metadata.get("timeout_ms", 0)
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None
try:
coro = invoke_fn(step.plugin_id, step.name, {
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
})
if timeout_sec:
await asyncio.wait_for(coro, timeout=timeout_sec)
else:
await coro
except Exception as e:
logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
async def _route_command(
self,
invoke_fn: InvokeFn,
message: dict[str, Any],
ctx: WorkflowContext,
) -> dict[str, Any] | None:
"""PLAN 阶段内置 Command 路由。
在 registry 中查找匹配的 command 组件,
匹配到则直接路由到对应 command handler返回执行结果。
不匹配则返回 None让 PLAN 阶段的 hook 继续执行。
"""
plain_text = message.get("plain_text", "")
if not plain_text:
return None
matched = self._registry.find_command_by_text(plain_text)
if matched is None:
return None
ctx.matched_command = matched.full_name
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
try:
resp = await invoke_fn(matched.plugin_id, matched.name, {
"text": plain_text,
"message": message,
"trace_id": ctx.trace_id,
})
return resp
except Exception as e:
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
ctx.errors.append(f"command:{matched.full_name}: {e}")
return None
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
"""返回 new 中与 old 不同的 key 列表。"""
changed = []
for k in new:
if k not in old or old[k] != new[k]:
changed.append(k)
return changed

View File

@@ -1 +1 @@
# Protocol 层 - RPC 消息模型、编解码、错误码

View File

@@ -1,13 +1,7 @@
"""MsgPack / JSON 编解码器 """MsgPack 编解码器"""
提供统一的消息编解码接口,生产环境默认使用 MsgPack
开发调试模式可切换为 JSON仅编解码切换传输层不变
"""
from typing import Any from typing import Any
import json
import msgpack import msgpack
from .envelope import Envelope from .envelope import Envelope
@@ -30,7 +24,7 @@ class Codec:
class MsgPackCodec(Codec): class MsgPackCodec(Codec):
"""MsgPack 编解码器(生产默认)""" """MsgPack 编解码器"""
def encode(self, obj: dict[str, Any]) -> bytes: def encode(self, obj: dict[str, Any]) -> bytes:
return msgpack.packb(obj, use_bin_type=True) return msgpack.packb(obj, use_bin_type=True)
@@ -47,34 +41,3 @@ class MsgPackCodec(Codec):
def decode_envelope(self, data: bytes) -> Envelope: def decode_envelope(self, data: bytes) -> Envelope:
raw = self.decode(data) raw = self.decode(data)
return Envelope.model_validate(raw) return Envelope.model_validate(raw)
class JsonCodec(Codec):
"""JSON 编解码器(开发调试用)"""
def encode(self, obj: dict[str, Any]) -> bytes:
return json.dumps(obj, ensure_ascii=False).encode("utf-8")
def decode(self, data: bytes) -> dict[str, Any]:
result = json.loads(data.decode("utf-8"))
if not isinstance(result, dict):
raise ValueError(f"期望解码为 dict实际为 {type(result)}")
return result
def encode_envelope(self, envelope: Envelope) -> bytes:
return self.encode(envelope.model_dump())
def decode_envelope(self, data: bytes) -> Envelope:
raw = self.decode(data)
return Envelope.model_validate(raw)
def create_codec(use_json: bool = False) -> Codec:
"""创建编解码器实例
Args:
use_json: 是否使用 JSON开发模式。默认使用 MsgPack。
"""
if use_json:
return JsonCodec()
return MsgPackCodec()

View File

@@ -1 +1 @@
# Runner 端 - 插件加载与执行进程

View File

@@ -0,0 +1,137 @@
"""Manifest 校验与版本兼容性
从旧系统的 ManifestValidator / VersionComparator 对齐移植,
适配新 plugin_runtime 的 _manifest.json 格式。
"""
from typing import Any
import logging
import re
logger = logging.getLogger("plugin_runtime.runner.manifest_validator")
class VersionComparator:
"""语义化版本号比较器"""
@staticmethod
def normalize_version(version: str) -> str:
if not version:
return "0.0.0"
normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
return "0.0.0"
parts = normalized.split(".")
while len(parts) < 3:
parts.append("0")
return ".".join(parts[:3])
@staticmethod
def parse_version(version: str) -> tuple[int, int, int]:
normalized = VersionComparator.normalize_version(version)
try:
parts = normalized.split(".")
return (int(parts[0]), int(parts[1]), int(parts[2]))
except (ValueError, IndexError):
return (0, 0, 0)
@staticmethod
def compare(v1: str, v2: str) -> int:
t1 = VersionComparator.parse_version(v1)
t2 = VersionComparator.parse_version(v2)
if t1 < t2:
return -1
elif t1 > t2:
return 1
return 0
@staticmethod
def is_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]:
if not min_version and not max_version:
return True, ""
vn = VersionComparator.normalize_version(version)
if min_version:
mn = VersionComparator.normalize_version(min_version)
if VersionComparator.compare(vn, mn) < 0:
return False, f"版本 {vn} 低于最小要求 {mn}"
if max_version:
mx = VersionComparator.normalize_version(max_version)
if VersionComparator.compare(vn, mx) > 0:
return False, f"版本 {vn} 高于最大支持 {mx}"
return True, ""
class ManifestValidator:
"""_manifest.json 校验器"""
REQUIRED_FIELDS = ["name", "version", "description", "author"]
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
SUPPORTED_MANIFEST_VERSIONS = [1, 2]
def __init__(self, host_version: str = ""):
self._host_version = host_version
self.errors: list[str] = []
self.warnings: list[str] = []
def validate(self, manifest: dict[str, Any]) -> bool:
"""校验 manifest 数据返回是否通过errors 为空即通过)。"""
self.errors.clear()
self.warnings.clear()
self._check_required_fields(manifest)
self._check_manifest_version(manifest)
self._check_author(manifest)
self._check_host_compatibility(manifest)
self._check_recommended(manifest)
if self.errors:
for e in self.errors:
logger.error(f"Manifest 校验失败: {e}")
if self.warnings:
for w in self.warnings:
logger.warning(f"Manifest 警告: {w}")
return len(self.errors) == 0
def _check_required_fields(self, manifest: dict[str, Any]) -> None:
for field in self.REQUIRED_FIELDS:
if field not in manifest:
self.errors.append(f"缺少必需字段: {field}")
elif not manifest[field]:
self.errors.append(f"必需字段不能为空: {field}")
def _check_manifest_version(self, manifest: dict[str, Any]) -> None:
mv = manifest.get("manifest_version")
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
self.errors.append(
f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}"
)
def _check_author(self, manifest: dict[str, Any]) -> None:
author = manifest.get("author")
if author is None:
return
if isinstance(author, dict):
if "name" not in author or not author["name"]:
self.errors.append("author 对象缺少 name 字段")
elif isinstance(author, str):
if not author.strip():
self.errors.append("author 不能为空")
else:
self.errors.append("author 应为字符串或 {name, url} 对象")
def _check_host_compatibility(self, manifest: dict[str, Any]) -> None:
host_app = manifest.get("host_application")
if not isinstance(host_app, dict) or not self._host_version:
return
min_v = host_app.get("min_version", "")
max_v = host_app.get("max_version", "")
ok, msg = VersionComparator.is_in_range(self._host_version, min_v, max_v)
if not ok:
self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})")
def _check_recommended(self, manifest: dict[str, Any]) -> None:
for field in self.RECOMMENDED_FIELDS:
if field not in manifest or not manifest[field]:
self.warnings.append(f"建议填写字段: {field}")

View File

@@ -2,8 +2,10 @@
在 Runner 进程中负责发现和加载插件。 在 Runner 进程中负责发现和加载插件。
插件通过 SDK 编写,不再 import src.*。 插件通过 SDK 编写,不再 import src.*。
支持manifest 校验、依赖解析(拓扑排序)、生命周期钩子。
""" """
from collections import deque
from typing import Any from typing import Any
import importlib import importlib
@@ -13,6 +15,8 @@ import logging
import os import os
import sys import sys
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
logger = logging.getLogger("plugin_runtime.runner.plugin_loader") logger = logging.getLogger("plugin_runtime.runner.plugin_loader")
@@ -32,6 +36,20 @@ class PluginMeta:
self.manifest = manifest self.manifest = manifest
self.version = manifest.get("version", "1.0.0") self.version = manifest.get("version", "1.0.0")
self.capabilities_required = manifest.get("capabilities", []) self.capabilities_required = manifest.get("capabilities", [])
self.dependencies: list[str] = self._extract_dependencies(manifest)
@staticmethod
def _extract_dependencies(manifest: dict[str, Any]) -> list[str]:
raw = manifest.get("dependencies", [])
result: list[str] = []
for dep in raw:
if isinstance(dep, str):
result.append(dep.strip())
elif isinstance(dep, dict):
name = str(dep.get("name", "")).strip()
if name:
result.append(name)
return result
class PluginLoader: class PluginLoader:
@@ -43,19 +61,22 @@ class PluginLoader:
- plugin.py: 插件入口模块(导出 create_plugin 工厂函数) - plugin.py: 插件入口模块(导出 create_plugin 工厂函数)
""" """
def __init__(self): def __init__(self, host_version: str = ""):
self._loaded_plugins: dict[str, PluginMeta] = {} self._loaded_plugins: dict[str, PluginMeta] = {}
self._failed_plugins: dict[str, str] = {}
self._manifest_validator = ManifestValidator(host_version=host_version)
def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]: def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]:
"""扫描多个目录并加载所有插件 """扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
Args: Args:
plugin_dirs: 插件目录列表 plugin_dirs: 插件目录列表
Returns: Returns:
成功加载的插件元数据列表 成功加载的插件元数据列表(按依赖顺序)
""" """
results = [] # 第一阶段:发现并校验 manifest
candidates: dict[str, tuple[str, dict[str, Any], str]] = {} # id -> (dir, manifest, plugin_path)
for base_dir in plugin_dirs: for base_dir in plugin_dirs:
if not os.path.isdir(base_dir): if not os.path.isdir(base_dir):
logger.warning(f"插件目录不存在: {base_dir}") logger.warning(f"插件目录不存在: {base_dir}")
@@ -73,12 +94,40 @@ class PluginLoader:
continue continue
try: try:
meta = self._load_single_plugin(plugin_dir, manifest_path, plugin_path) with open(manifest_path, "r", encoding="utf-8") as f:
if meta: manifest = json.load(f)
self._loaded_plugins[meta.plugin_id] = meta
results.append(meta)
except Exception as e: except Exception as e:
logger.error(f"加载插件失败 [{plugin_dir}]: {e}", exc_info=True) self._failed_plugins[entry] = f"manifest 解析失败: {e}"
logger.error(f"插件 {entry} manifest 解析失败: {e}")
continue
if not self._manifest_validator.validate(manifest):
errors = "; ".join(self._manifest_validator.errors)
self._failed_plugins[entry] = f"manifest 校验失败: {errors}"
continue
plugin_id = manifest.get("name", entry)
candidates[plugin_id] = (plugin_dir, manifest, plugin_path)
# 第二阶段:依赖解析(拓扑排序)
load_order, failed_deps = self._resolve_dependencies(candidates)
for pid, reason in failed_deps.items():
self._failed_plugins[pid] = reason
logger.error(f"插件 {pid} 依赖解析失败: {reason}")
# 第三阶段:按依赖顺序加载
results = []
for plugin_id in load_order:
plugin_dir, manifest, plugin_path = candidates[plugin_id]
try:
meta = self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path)
if meta:
self._loaded_plugins[meta.plugin_id] = meta
results.append(meta)
except Exception as e:
self._failed_plugins[plugin_id] = str(e)
logger.error(f"加载插件失败 [{plugin_id}]: {e}", exc_info=True)
return results return results
@@ -90,15 +139,78 @@ class PluginLoader:
"""列出所有已加载的插件 ID""" """列出所有已加载的插件 ID"""
return list(self._loaded_plugins.keys()) return list(self._loaded_plugins.keys())
def _load_single_plugin(self, plugin_dir: str, manifest_path: str, plugin_path: str) -> PluginMeta | None: @property
def failed_plugins(self) -> dict[str, str]:
return dict(self._failed_plugins)
# ──── 依赖解析 ────────────────────────────────────────────
def _resolve_dependencies(
self,
candidates: dict[str, tuple[str, dict[str, Any], str]],
) -> tuple[list[str], dict[str, str]]:
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
available = set(candidates.keys())
dep_graph: dict[str, set[str]] = {}
failed: dict[str, str] = {}
for pid, (_, manifest, _) in candidates.items():
raw_deps = manifest.get("dependencies", [])
resolved: set[str] = set()
missing: list[str] = []
for dep in raw_deps:
dep_name = dep if isinstance(dep, str) else str(dep.get("name", ""))
dep_name = dep_name.strip()
if not dep_name or dep_name == pid:
continue
if dep_name in available:
resolved.add(dep_name)
else:
missing.append(dep_name)
if missing:
failed[pid] = f"缺少依赖: {', '.join(missing)}"
dep_graph[pid] = resolved
# 移除失败项
for pid in failed:
dep_graph.pop(pid, None)
# Kahn 拓扑排序
indegree = {pid: len(deps) for pid, deps in dep_graph.items()}
reverse: dict[str, set[str]] = {pid: set() for pid in dep_graph}
for pid, deps in dep_graph.items():
for d in deps:
if d in reverse:
reverse[d].add(pid)
queue = deque(sorted(pid for pid, deg in indegree.items() if deg == 0))
sorted_order: list[str] = []
while queue:
current = queue.popleft()
sorted_order.append(current)
for dependent in sorted(reverse.get(current, [])):
indegree[dependent] -= 1
if indegree[dependent] == 0:
queue.append(dependent)
cycle_plugins = {pid for pid, deg in indegree.items() if deg > 0}
for pid in cycle_plugins:
failed[pid] = "检测到循环依赖"
return sorted_order, failed
# ──── 单个插件加载 ────────────────────────────────────────
def _load_single_plugin(
self,
plugin_id: str,
plugin_dir: str,
manifest: dict[str, Any],
plugin_path: str,
) -> PluginMeta | None:
"""加载单个插件""" """加载单个插件"""
# 1. 读取 manifest # 动态导入插件模块
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json.load(f)
plugin_id = os.path.basename(plugin_dir)
# 2. 动态导入插件模块
module_name = f"_maibot_plugin_{plugin_id}" module_name = f"_maibot_plugin_{plugin_id}"
spec = importlib.util.spec_from_file_location(module_name, plugin_path) spec = importlib.util.spec_from_file_location(module_name, plugin_path)
if spec is None or spec.loader is None: if spec is None or spec.loader is None:
@@ -109,7 +221,7 @@ class PluginLoader:
sys.modules[module_name] = module sys.modules[module_name] = module
spec.loader.exec_module(module) spec.loader.exec_module(module)
# 3. 调用工厂函数创建插件实例 # 调用工厂函数创建插件实例
create_plugin = getattr(module, "create_plugin", None) create_plugin = getattr(module, "create_plugin", None)
if create_plugin is None: if create_plugin is None:
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数") logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数")

View File

@@ -14,7 +14,7 @@ import asyncio
import logging import logging
import uuid import uuid
from src.plugin_runtime.protocol.codec import Codec, create_codec from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
from src.plugin_runtime.protocol.envelope import ( from src.plugin_runtime.protocol.envelope import (
PROTOCOL_VERSION, PROTOCOL_VERSION,
Envelope, Envelope,
@@ -49,7 +49,7 @@ class RPCClient:
): ):
self._host_address = host_address self._host_address = host_address
self._session_token = session_token self._session_token = session_token
self._codec = codec or create_codec() self._codec = codec or MsgPackCodec()
self._id_gen = RequestIdGenerator() self._id_gen = RequestIdGenerator()
self._connection: Connection | None = None self._connection: Connection | None = None

View File

@@ -71,7 +71,23 @@ class PluginRunner:
plugins = self._loader.discover_and_load(self._plugin_dirs) plugins = self._loader.discover_and_load(self._plugin_dirs)
logger.info(f"已加载 {len(plugins)} 个插件") logger.info(f"已加载 {len(plugins)} 个插件")
# 4. 向 Host 注册所有插件的组件 # 4. 调用 on_load 生命周期钩子 + 注入 RPC 客户端供 SDK context 使用
for meta in plugins:
instance = meta.instance
# 注入 _rpc_client 以便 PluginContext 可以发起能力调用
if hasattr(instance, "_ctx"):
ctx = instance._ctx
if hasattr(ctx, "_set_rpc_client"):
ctx._set_rpc_client(self._rpc_client)
if hasattr(instance, "on_load"):
try:
ret = instance.on_load()
if asyncio.iscoroutine(ret):
await ret
except Exception as e:
logger.error(f"插件 {meta.plugin_id} on_load 失败: {e}", exc_info=True)
# 5. 向 Host 注册所有插件的组件
for meta in plugins: for meta in plugins:
await self._register_plugin(meta) await self._register_plugin(meta)
@@ -92,6 +108,7 @@ class PluginRunner:
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
self._rpc_client.register_method("plugin.health", self._handle_health) self._rpc_client.register_method("plugin.health", self._handle_health)
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
@@ -169,6 +186,57 @@ class PluginRunner:
resp_payload = InvokeResultPayload(success=False, result=str(e)) resp_payload = InvokeResultPayload(success=False, result=str(e))
return envelope.make_response(payload=resp_payload.model_dump()) return envelope.make_response(payload=resp_payload.model_dump())
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
"""处理 WorkflowStep 调用请求
与通用 invoke 不同,会将返回值规范化为
{hook_result, modified_message, stage_output} 格式。
"""
try:
invoke = InvokePayload.model_validate(envelope.payload)
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
f"插件 {plugin_id} 未加载",
)
instance = meta.instance
component_name = invoke.component_name
handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"插件 {plugin_id} 无组件: {component_name}",
)
try:
raw = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
# 规范化返回值
if raw is None:
result = {"hook_result": "continue"}
elif isinstance(raw, str):
# 允许直接返回 hook_result 字符串
result = {"hook_result": raw}
elif isinstance(raw, dict):
result = raw
result.setdefault("hook_result", "continue")
else:
result = {"hook_result": "continue"}
resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except Exception as e:
logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True)
resp_payload = InvokeResultPayload(success=False, result=str(e))
return envelope.make_response(payload=resp_payload.model_dump())
async def _handle_health(self, envelope: Envelope) -> Envelope: async def _handle_health(self, envelope: Envelope) -> Envelope:
"""处理健康检查""" """处理健康检查"""
uptime_ms = int((time.monotonic() - self._start_time) * 1000) uptime_ms = int((time.monotonic() - self._start_time) * 1000)
@@ -185,8 +253,17 @@ class PluginRunner:
return envelope.make_response(payload={"acknowledged": True}) return envelope.make_response(payload={"acknowledged": True})
async def _handle_shutdown(self, envelope: Envelope) -> Envelope: async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
"""处理关停""" """处理关停 — 调用所有插件的 on_unload 后退出"""
logger.info("收到 shutdown 信号,准备退出") logger.info("收到 shutdown 信号,开始调用 on_unload")
for plugin_id in self._loader.list_plugins():
meta = self._loader.get_plugin(plugin_id)
if meta and hasattr(meta.instance, "on_unload"):
try:
ret = meta.instance.on_unload()
if asyncio.iscoroutine(ret):
await ret
except Exception as e:
logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True)
self._shutting_down = True self._shutting_down = True
return envelope.make_response(payload={"acknowledged": True}) return envelope.make_response(payload={"acknowledged": True})
@@ -209,6 +286,43 @@ class PluginRunner:
return self._rpc_client return self._rpc_client
# ─── sys.path 隔离 ────────────────────────────────────────
def _isolate_sys_path(plugin_dirs: list[str]) -> None:
"""清理 sys.path限制 Runner 子进程只能访问标准库、SDK 和插件目录。
防止插件代码 import 主程序模块读取运行时数据。
"""
import sysconfig
# 保留: 标准库路径 + site-packages含 SDK 和依赖)
stdlib_paths = set()
for key in ("stdlib", "platstdlib", "purelib", "platlib"):
path = sysconfig.get_path(key)
if path:
stdlib_paths.add(os.path.normpath(path))
allowed = set()
for p in sys.path:
norm = os.path.normpath(p)
# 保留标准库和 site-packages
if any(norm.startswith(sp) for sp in stdlib_paths):
allowed.add(p)
# 保留 site-packages第三方库 + SDK
if "site-packages" in norm or "dist-packages" in norm:
allowed.add(p)
# 添加插件目录
for d in plugin_dirs:
allowed.add(os.path.normpath(d))
# 添加当前 runner 模块所在路径(使得 src.plugin_runtime 可导入)
runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
allowed.add(runtime_root)
sys.path[:] = [p for p in sys.path if p in allowed]
# ─── 进程入口 ────────────────────────────────────────────── # ─── 进程入口 ──────────────────────────────────────────────
async def _async_main() -> None: async def _async_main() -> None:
@@ -223,6 +337,9 @@ async def _async_main() -> None:
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d] plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
# sys.path 隔离: 只保留标准库、SDK 包、插件目录
_isolate_sys_path(plugin_dirs)
runner = PluginRunner(host_address, session_token, plugin_dirs) runner = PluginRunner(host_address, session_token, plugin_dirs)
# 注册信号处理 # 注册信号处理

View File

@@ -1 +1 @@
# Transport 层 - 跨平台本地 IPC 传输抽象