feat: 添加插件身份绑定机制,防止伪造插件身份的 RPC 调用

This commit is contained in:
DrSmoothl
2026-03-13 15:40:14 +08:00
parent 44a9e9ecd7
commit f1e10b4054
3 changed files with 100 additions and 12 deletions

View File

@@ -313,6 +313,43 @@ class TestSDK:
assert plugin.ctx.llm is not None assert plugin.ctx.llm is not None
assert plugin.ctx.config is not None assert plugin.ctx.config is not None
@pytest.mark.asyncio
async def test_runner_injected_context_binds_plugin_identity(self):
"""Runner 注入的上下文应忽略调用方伪造的 plugin_id。"""
from src.plugin_runtime.runner.runner_main import PluginRunner
class DummyRPCClient:
def __init__(self):
self.calls = []
async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000):
self.calls.append(
{
"method": method,
"plugin_id": plugin_id,
"payload": payload,
"timeout_ms": timeout_ms,
}
)
return SimpleNamespace(error=None, payload={"result": {"ok": True}})
class DummyPlugin:
def _set_context(self, ctx):
self.ctx = ctx
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
runner._rpc_client = DummyRPCClient()
plugin = DummyPlugin()
runner._inject_context("owner_plugin", plugin)
plugin.ctx._plugin_id = "forged_plugin"
result = await plugin.ctx.call_capability("send.text", text="hello", stream_id="stream-1")
assert result == {"ok": True}
assert runner._rpc_client.calls[0]["plugin_id"] == "owner_plugin"
assert runner._rpc_client.calls[0]["method"] == "cap.request"
# ─── 端到端集成测试 ──────────────────────────────────────── # ─── 端到端集成测试 ────────────────────────────────────────
@@ -1177,6 +1214,46 @@ class TestWorkflowExecutor:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
assert result.status == "completed" assert result.status == "completed"
@pytest.mark.asyncio
async def test_nonblocking_tasks_are_retained_until_completion(self):
"""execute 返回后non-blocking task 仍应保持强引用直到执行完成。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component(
"observer",
"workflow_step",
"p1",
{
"stage": "post_process",
"priority": 0,
"blocking": False,
},
)
executor = WorkflowExecutor(reg)
started = asyncio.Event()
release = asyncio.Event()
async def mock_invoke(plugin_id, comp_name, args):
started.set()
await release.wait()
return {"hook_result": "continue"}
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
await asyncio.sleep(0)
assert result.status == "completed"
assert final_msg["plain_text"] == "original"
assert started.is_set()
assert len(executor._background_tasks) == 1
release.set()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert not executor._background_tasks
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_command_routing(self): async def test_command_routing(self):
"""PLAN 阶段内置命令路由""" """PLAN 阶段内置命令路由"""

View File

@@ -17,7 +17,7 @@
- modification_log: 消息修改审计 - modification_log: 消息修改审计
""" """
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
import asyncio import asyncio
import time import time
@@ -113,6 +113,7 @@ class WorkflowExecutor:
def __init__(self, registry: ComponentRegistry) -> None: def __init__(self, registry: ComponentRegistry) -> None:
self._registry = registry self._registry = registry
self._background_tasks: Set[asyncio.Task] = set()
async def execute( async def execute(
self, self,
@@ -134,8 +135,6 @@ class WorkflowExecutor:
""" """
ctx = context or WorkflowContext(stream_id=stream_id) ctx = context or WorkflowContext(stream_id=stream_id)
current_message = dict(message) if message else None current_message = dict(message) if message else None
# 保持非阻塞任务引用,防止被 GC 回收
background_tasks: List[asyncio.Task] = []
for stage in STAGE_SEQUENCE: for stage in STAGE_SEQUENCE:
stage_start = time.perf_counter() stage_start = time.perf_counter()
@@ -220,14 +219,12 @@ class WorkflowExecutor:
# 4. 并发执行 non-blocking hook只读忽略返回值中的 modified_message # 4. 并发执行 non-blocking hook只读忽略返回值中的 modified_message
if nonblocking_steps and not skip_stage: if nonblocking_steps and not skip_stage:
nb_tasks = [ for step in nonblocking_steps:
asyncio.create_task( self._track_background_task(
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message) asyncio.create_task(
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
)
) )
for step in nonblocking_steps
]
# 保持任务引用以防止被 GC 回收
background_tasks.extend(nb_tasks)
ctx.timings[stage] = time.perf_counter() - stage_start ctx.timings[stage] = time.perf_counter() - stage_start
@@ -256,6 +253,11 @@ class WorkflowExecutor:
ctx, ctx,
) )
def _track_background_task(self, task: asyncio.Task) -> None:
"""保持 non-blocking workflow task 的强引用,直到任务结束。"""
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
# ─── 内部方法 ────────────────────────────────────────────── # ─── 内部方法 ──────────────────────────────────────────────
def _pre_filter( def _pre_filter(

View File

@@ -186,12 +186,21 @@ class PluginRunner:
return return
rpc_client = self._rpc_client rpc_client = self._rpc_client
bound_plugin_id = plugin_id
async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any: async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any:
"""桥接 PluginContext.call_capability → RPCClient.send_request""" """桥接 PluginContext.call_capability → RPCClient.send_request
无论调用方传入何种 plugin_id实际发往 Host 的 plugin_id
始终绑定为当前插件实例,避免伪造其他插件身份申请能力。
"""
if plugin_id and plugin_id != bound_plugin_id:
logger.warning(
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC已强制绑定回自身身份"
)
resp = await rpc_client.send_request( resp = await rpc_client.send_request(
method=method, method=method,
plugin_id=plugin_id, plugin_id=bound_plugin_id,
payload=payload or {}, payload=payload or {},
) )
# 从响应信封中提取业务结果 # 从响应信封中提取业务结果