From f1e10b4054c422cf173ba4dfd76fabb30e5a218e Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 13 Mar 2026 15:40:14 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E8=BA=AB=E4=BB=BD=E7=BB=91=E5=AE=9A=E6=9C=BA=E5=88=B6=EF=BC=8C?= =?UTF-8?q?=E9=98=B2=E6=AD=A2=E4=BC=AA=E9=80=A0=E6=8F=92=E4=BB=B6=E8=BA=AB?= =?UTF-8?q?=E4=BB=BD=E7=9A=84=20RPC=20=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_plugin_runtime.py | 77 ++++++++++++++++++++ src/plugin_runtime/host/workflow_executor.py | 22 +++--- src/plugin_runtime/runner/runner_main.py | 13 +++- 3 files changed, 100 insertions(+), 12 deletions(-) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index db210bc6..b00c6142 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -313,6 +313,43 @@ class TestSDK: assert plugin.ctx.llm is not None assert plugin.ctx.config is not None + @pytest.mark.asyncio + async def test_runner_injected_context_binds_plugin_identity(self): + """Runner 注入的上下文应忽略调用方伪造的 plugin_id。""" + from src.plugin_runtime.runner.runner_main import PluginRunner + + class DummyRPCClient: + def __init__(self): + self.calls = [] + + async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): + self.calls.append( + { + "method": method, + "plugin_id": plugin_id, + "payload": payload, + "timeout_ms": timeout_ms, + } + ) + return SimpleNamespace(error=None, payload={"result": {"ok": True}}) + + class DummyPlugin: + def _set_context(self, ctx): + self.ctx = ctx + + runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) + runner._rpc_client = DummyRPCClient() + + plugin = DummyPlugin() + runner._inject_context("owner_plugin", plugin) + + plugin.ctx._plugin_id = "forged_plugin" + result = await plugin.ctx.call_capability("send.text", text="hello", stream_id="stream-1") + + assert result == {"ok": True} + assert runner._rpc_client.calls[0]["plugin_id"] == "owner_plugin" + assert runner._rpc_client.calls[0]["method"] == "cap.request" + # ─── 端到端集成测试 ──────────────────────────────────────── @@ -1177,6 +1214,46 @@ class TestWorkflowExecutor: await asyncio.sleep(0.1) assert result.status == "completed" + @pytest.mark.asyncio + async def test_nonblocking_tasks_are_retained_until_completion(self): + """execute 返回后,non-blocking task 仍应保持强引用直到执行完成。""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + reg.register_component( + "observer", + "workflow_step", + "p1", + { + "stage": "post_process", + "priority": 0, + "blocking": False, + }, + ) + executor = WorkflowExecutor(reg) + + started = asyncio.Event() + release = asyncio.Event() + + async def mock_invoke(plugin_id, comp_name, args): + started.set() + await release.wait() + return {"hook_result": "continue"} + + result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"}) + + await asyncio.sleep(0) + assert result.status == "completed" + assert final_msg["plain_text"] == "original" + assert started.is_set() + assert len(executor._background_tasks) == 1 + + release.set() + await asyncio.sleep(0) + await asyncio.sleep(0) + assert not executor._background_tasks + @pytest.mark.asyncio async def test_command_routing(self): """PLAN 阶段内置命令路由""" diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py index 732a3888..3037e9dd 100644 --- a/src/plugin_runtime/host/workflow_executor.py +++ b/src/plugin_runtime/host/workflow_executor.py @@ -17,7 +17,7 @@ - modification_log: 消息修改审计 """ -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple import asyncio import time @@ -113,6 +113,7 @@ class WorkflowExecutor: def __init__(self, registry: ComponentRegistry) -> None: self._registry = registry + self._background_tasks: Set[asyncio.Task] = set() async def execute( self, @@ -134,8 +135,6 @@ class WorkflowExecutor: """ ctx = context or WorkflowContext(stream_id=stream_id) current_message = dict(message) if message else None - # 保持非阻塞任务引用,防止被 GC 回收 - background_tasks: List[asyncio.Task] = [] for stage in STAGE_SEQUENCE: stage_start = time.perf_counter() @@ -220,14 +219,12 @@ class WorkflowExecutor: # 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message) if nonblocking_steps and not skip_stage: - nb_tasks = [ - asyncio.create_task( - self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message) + for step in nonblocking_steps: + self._track_background_task( + asyncio.create_task( + self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message) + ) ) - for step in nonblocking_steps - ] - # 保持任务引用以防止被 GC 回收 - background_tasks.extend(nb_tasks) ctx.timings[stage] = time.perf_counter() - stage_start @@ -256,6 +253,11 @@ class WorkflowExecutor: ctx, ) + def _track_background_task(self, task: asyncio.Task) -> None: + """保持 non-blocking workflow task 的强引用,直到任务结束。""" + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + # ─── 内部方法 ────────────────────────────────────────────── def _pre_filter( diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index f881ab25..4873bede 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -186,12 +186,21 @@ class PluginRunner: return rpc_client = self._rpc_client + bound_plugin_id = plugin_id async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any: - """桥接 PluginContext.call_capability → RPCClient.send_request""" + """桥接 PluginContext.call_capability → RPCClient.send_request。 + + 无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id + 始终绑定为当前插件实例,避免伪造其他插件身份申请能力。 + """ + if plugin_id and plugin_id != bound_plugin_id: + logger.warning( + f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份" + ) resp = await rpc_client.send_request( method=method, - plugin_id=plugin_id, + plugin_id=bound_plugin_id, payload=payload or {}, ) # 从响应信封中提取业务结果