feat: 添加插件身份绑定机制,防止伪造插件身份的 RPC 调用
This commit is contained in:
@@ -313,6 +313,43 @@ class TestSDK:
|
|||||||
assert plugin.ctx.llm is not None
|
assert plugin.ctx.llm is not None
|
||||||
assert plugin.ctx.config is not None
|
assert plugin.ctx.config is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_injected_context_binds_plugin_identity(self):
|
||||||
|
"""Runner 注入的上下文应忽略调用方伪造的 plugin_id。"""
|
||||||
|
from src.plugin_runtime.runner.runner_main import PluginRunner
|
||||||
|
|
||||||
|
class DummyRPCClient:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000):
|
||||||
|
self.calls.append(
|
||||||
|
{
|
||||||
|
"method": method,
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"payload": payload,
|
||||||
|
"timeout_ms": timeout_ms,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return SimpleNamespace(error=None, payload={"result": {"ok": True}})
|
||||||
|
|
||||||
|
class DummyPlugin:
|
||||||
|
def _set_context(self, ctx):
|
||||||
|
self.ctx = ctx
|
||||||
|
|
||||||
|
runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[])
|
||||||
|
runner._rpc_client = DummyRPCClient()
|
||||||
|
|
||||||
|
plugin = DummyPlugin()
|
||||||
|
runner._inject_context("owner_plugin", plugin)
|
||||||
|
|
||||||
|
plugin.ctx._plugin_id = "forged_plugin"
|
||||||
|
result = await plugin.ctx.call_capability("send.text", text="hello", stream_id="stream-1")
|
||||||
|
|
||||||
|
assert result == {"ok": True}
|
||||||
|
assert runner._rpc_client.calls[0]["plugin_id"] == "owner_plugin"
|
||||||
|
assert runner._rpc_client.calls[0]["method"] == "cap.request"
|
||||||
|
|
||||||
|
|
||||||
# ─── 端到端集成测试 ────────────────────────────────────────
|
# ─── 端到端集成测试 ────────────────────────────────────────
|
||||||
|
|
||||||
@@ -1177,6 +1214,46 @@ class TestWorkflowExecutor:
|
|||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_nonblocking_tasks_are_retained_until_completion(self):
|
||||||
|
"""execute 返回后,non-blocking task 仍应保持强引用直到执行完成。"""
|
||||||
|
from src.plugin_runtime.host.component_registry import ComponentRegistry
|
||||||
|
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
|
||||||
|
|
||||||
|
reg = ComponentRegistry()
|
||||||
|
reg.register_component(
|
||||||
|
"observer",
|
||||||
|
"workflow_step",
|
||||||
|
"p1",
|
||||||
|
{
|
||||||
|
"stage": "post_process",
|
||||||
|
"priority": 0,
|
||||||
|
"blocking": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
executor = WorkflowExecutor(reg)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
|
||||||
|
async def mock_invoke(plugin_id, comp_name, args):
|
||||||
|
started.set()
|
||||||
|
await release.wait()
|
||||||
|
return {"hook_result": "continue"}
|
||||||
|
|
||||||
|
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert final_msg["plain_text"] == "original"
|
||||||
|
assert started.is_set()
|
||||||
|
assert len(executor._background_tasks) == 1
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert not executor._background_tasks
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_command_routing(self):
|
async def test_command_routing(self):
|
||||||
"""PLAN 阶段内置命令路由"""
|
"""PLAN 阶段内置命令路由"""
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
- modification_log: 消息修改审计
|
- modification_log: 消息修改审计
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
@@ -113,6 +113,7 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
def __init__(self, registry: ComponentRegistry) -> None:
|
def __init__(self, registry: ComponentRegistry) -> None:
|
||||||
self._registry = registry
|
self._registry = registry
|
||||||
|
self._background_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
@@ -134,8 +135,6 @@ class WorkflowExecutor:
|
|||||||
"""
|
"""
|
||||||
ctx = context or WorkflowContext(stream_id=stream_id)
|
ctx = context or WorkflowContext(stream_id=stream_id)
|
||||||
current_message = dict(message) if message else None
|
current_message = dict(message) if message else None
|
||||||
# 保持非阻塞任务引用,防止被 GC 回收
|
|
||||||
background_tasks: List[asyncio.Task] = []
|
|
||||||
|
|
||||||
for stage in STAGE_SEQUENCE:
|
for stage in STAGE_SEQUENCE:
|
||||||
stage_start = time.perf_counter()
|
stage_start = time.perf_counter()
|
||||||
@@ -220,14 +219,12 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
||||||
if nonblocking_steps and not skip_stage:
|
if nonblocking_steps and not skip_stage:
|
||||||
nb_tasks = [
|
for step in nonblocking_steps:
|
||||||
asyncio.create_task(
|
self._track_background_task(
|
||||||
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
|
asyncio.create_task(
|
||||||
|
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for step in nonblocking_steps
|
|
||||||
]
|
|
||||||
# 保持任务引用以防止被 GC 回收
|
|
||||||
background_tasks.extend(nb_tasks)
|
|
||||||
|
|
||||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||||
|
|
||||||
@@ -256,6 +253,11 @@ class WorkflowExecutor:
|
|||||||
ctx,
|
ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _track_background_task(self, task: asyncio.Task) -> None:
|
||||||
|
"""保持 non-blocking workflow task 的强引用,直到任务结束。"""
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
# ─── 内部方法 ──────────────────────────────────────────────
|
# ─── 内部方法 ──────────────────────────────────────────────
|
||||||
|
|
||||||
def _pre_filter(
|
def _pre_filter(
|
||||||
|
|||||||
@@ -186,12 +186,21 @@ class PluginRunner:
|
|||||||
return
|
return
|
||||||
|
|
||||||
rpc_client = self._rpc_client
|
rpc_client = self._rpc_client
|
||||||
|
bound_plugin_id = plugin_id
|
||||||
|
|
||||||
async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any:
|
async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any:
|
||||||
"""桥接 PluginContext.call_capability → RPCClient.send_request"""
|
"""桥接 PluginContext.call_capability → RPCClient.send_request。
|
||||||
|
|
||||||
|
无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id
|
||||||
|
始终绑定为当前插件实例,避免伪造其他插件身份申请能力。
|
||||||
|
"""
|
||||||
|
if plugin_id and plugin_id != bound_plugin_id:
|
||||||
|
logger.warning(
|
||||||
|
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份"
|
||||||
|
)
|
||||||
resp = await rpc_client.send_request(
|
resp = await rpc_client.send_request(
|
||||||
method=method,
|
method=method,
|
||||||
plugin_id=plugin_id,
|
plugin_id=bound_plugin_id,
|
||||||
payload=payload or {},
|
payload=payload or {},
|
||||||
)
|
)
|
||||||
# 从响应信封中提取业务结果
|
# 从响应信封中提取业务结果
|
||||||
|
|||||||
Reference in New Issue
Block a user