diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index ab4835a4..2688f1c4 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -886,7 +886,7 @@ class TestWorkflowExecutor: await executor.execute( mock_invoke, message={"plain_text": "hi", "chat_type": "group"} ) - assert call_log == [] + assert not call_log # 匹配 filter —— hook 应被调用 await executor.execute( diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 359dad41..a4793020 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -42,8 +42,7 @@ class RegisteredComponent: # 预编译命令正则(仅 command 类型) self._compiled_pattern: re.Pattern | None = None if component_type == "command": - pattern = metadata.get("command_pattern", "") - if pattern: + if pattern := metadata.get("command_pattern", ""): try: self._compiled_pattern = re.compile(pattern) except re.error as e: @@ -120,8 +119,7 @@ class ComponentRegistry: comps = self._by_plugin.pop(plugin_id, []) for comp in comps: self._components.pop(comp.full_name, None) - type_dict = self._by_type.get(comp.component_type) - if type_dict: + if type_dict := self._by_type.get(comp.component_type): type_dict.pop(comp.full_name, None) return len(comps) @@ -162,9 +160,7 @@ class ComponentRegistry: ) -> list[RegisteredComponent]: """按插件查询。""" comps = self._by_plugin.get(plugin_id, []) - if enabled_only: - return [c for c in comps if c.enabled] - return list(comps) + return [c for c in comps if c.enabled] if enabled_only else list(comps) def find_command_by_text(self, text: str) -> RegisteredComponent | None: """通过文本匹配命令正则,返回第一个匹配的 command 组件。""" diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 7ec28a74..0f42f0fc 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -93,7 +93,7 @@ class RPCServer: self._running = False # 取消所有 pending 请求 - for req_id, future in self._pending_requests.items(): + for future in self._pending_requests.values(): if not future.done(): future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) self._pending_requests.clear() @@ -162,16 +162,15 @@ class RPCServer: # 等待响应 timeout_sec = timeout_ms / 1000.0 - response = await asyncio.wait_for(future, timeout=timeout_sec) - return response + return await asyncio.wait_for(future, timeout=timeout_sec) except asyncio.TimeoutError: self._pending_requests.pop(request_id, None) - raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") + raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None except Exception as e: self._pending_requests.pop(request_id, None) if isinstance(e, RPCError): raise - raise RPCError(ErrorCode.E_UNKNOWN, str(e)) + raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None: """向 Runner 发送单向事件(不等待响应)""" @@ -338,8 +337,7 @@ class RPCServer: async def _handle_event(self, envelope: Envelope) -> None: """处理来自 Runner 的事件""" - handler = self._method_handlers.get(envelope.method) - if handler: + if handler := self._method_handlers.get(envelope.method): try: await handler(envelope) except Exception as e: diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 339d9bbc..041b3a68 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -203,19 +203,15 @@ class PluginSupervisor: 由主进程业务逻辑调用,通过 RPC 转发给 Runner。 """ - try: - response = await self._rpc_server.send_request( - method=method, - plugin_id=plugin_id, - payload={ - "component_name": component_name, - "args": args or {}, - }, - timeout_ms=timeout_ms, - ) - return response - except RPCError: - raise + return await self._rpc_server.send_request( + method=method, + plugin_id=plugin_id, + payload={ + "component_name": component_name, + "args": args or {}, + }, + timeout_ms=timeout_ms, + ) async def reload_plugins(self, reason: str = "manual") -> None: """热重载所有插件(进程级 generation 切换) diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py index 8e2937db..01eb3f3b 100644 --- a/src/plugin_runtime/host/workflow_executor.py +++ b/src/plugin_runtime/host/workflow_executor.py @@ -17,7 +17,7 @@ - modification_log: 消息修改审计 """ -from typing import Any, Callable, Awaitable, Optional +from typing import Any, Awaitable, Callable import asyncio import logging @@ -273,10 +273,9 @@ class WorkflowExecutor: """ for key, expected in filter_cond.items(): actual = message.get(key) - if isinstance(expected, list): - if actual not in expected: - return False - elif actual != expected: + if (isinstance(expected, list) and actual not in expected) or ( + not isinstance(expected, list) and actual != expected + ): return False return True @@ -376,12 +375,11 @@ class WorkflowExecutor: logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}") try: - resp = await invoke_fn(matched.plugin_id, matched.name, { + return await invoke_fn(matched.plugin_id, matched.name, { "text": plain_text, "message": message, "trace_id": ctx.trace_id, }) - return resp except Exception as e: logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True) ctx.errors.append(f"command:{matched.full_name}: {e}") @@ -390,8 +388,4 @@ class WorkflowExecutor: def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]: """返回 new 中与 old 不同的 key 列表。""" - changed = [] - for k in new: - if k not in old or old[k] != new[k]: - changed.append(k) - return changed + return [k for k, v in new.items() if k not in old or old[k] != v] diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 59775e7f..26d8f16f 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -46,8 +46,7 @@ class PluginMeta: if isinstance(dep, str): result.append(dep.strip()) elif isinstance(dep, dict): - name = str(dep.get("name", "")).strip() - if name: + if name := str(dep.get("name", "")).strip(): result.append(name) return result @@ -121,8 +120,7 @@ class PluginLoader: for plugin_id in load_order: plugin_dir, manifest, plugin_path = candidates[plugin_id] try: - meta = self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path) - if meta: + if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path): self._loaded_plugins[meta.plugin_id] = meta results.append(meta) except Exception as e: diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py index b2052171..0e525e34 100644 --- a/src/plugin_runtime/runner/rpc_client.py +++ b/src/plugin_runtime/runner/rpc_client.py @@ -11,12 +11,12 @@ from typing import Any, Callable, Awaitable import asyncio +import contextlib import logging import uuid from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec from src.plugin_runtime.protocol.envelope import ( - PROTOCOL_VERSION, Envelope, HelloPayload, HelloResponsePayload, @@ -129,10 +129,8 @@ class RPCClient: self._running = False if self._recv_task: self._recv_task.cancel() - try: + with contextlib.suppress(asyncio.CancelledError): await self._recv_task - except asyncio.CancelledError: - pass self._recv_task = None # 取消所有 pending 请求 @@ -176,16 +174,15 @@ class RPCClient: await self._connection.send_frame(data) timeout_sec = timeout_ms / 1000.0 - response = await asyncio.wait_for(future, timeout=timeout_sec) - return response + return await asyncio.wait_for(future, timeout=timeout_sec) except asyncio.TimeoutError: self._pending_requests.pop(request_id, None) - raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") + raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None except Exception as e: self._pending_requests.pop(request_id, None) if isinstance(e, RPCError): raise - raise RPCError(ErrorCode.E_UNKNOWN, str(e)) + raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e # ─── 内部方法 ────────────────────────────────────────────── @@ -249,8 +246,7 @@ class RPCClient: async def _handle_event(self, envelope: Envelope) -> None: """处理来自 Host 的事件""" - handler = self._method_handlers.get(envelope.method) - if handler: + if handler := self._method_handlers.get(envelope.method): try: await handler(envelope) except Exception as e: diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 61a90fcd..5c458bf2 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -12,6 +12,7 @@ from typing import Any import asyncio +import contextlib import logging import os import signal @@ -92,11 +93,9 @@ class PluginRunner: await self._register_plugin(meta) # 5. 等待直到收到关停信号 - try: + with contextlib.suppress(asyncio.CancelledError): while not self._shutting_down: await asyncio.sleep(1.0) - except asyncio.CancelledError: - pass # 6. 断开连接 await self._rpc_client.disconnect() @@ -122,13 +121,15 @@ class PluginRunner: # 从插件实例获取组件声明(SDK 插件须实现 get_components 方法) if hasattr(instance, "get_components"): - for comp_info in instance.get_components(): - components.append(ComponentDeclaration( + components.extend( + ComponentDeclaration( name=comp_info.get("name", ""), component_type=comp_info.get("type", ""), plugin_id=meta.plugin_id, metadata=comp_info.get("metadata", {}), - )) + ) + for comp_info in instance.get_components() + ) reg_payload = RegisterComponentsPayload( plugin_id=meta.plugin_id, @@ -219,10 +220,7 @@ class PluginRunner: raw = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args) # 规范化返回值 - if raw is None: - result = {"hook_result": "continue"} - elif isinstance(raw, str): - # 允许直接返回 hook_result 字符串 + if isinstance(raw, str): result = {"hook_result": raw} elif isinstance(raw, dict): result = raw @@ -298,8 +296,7 @@ def _isolate_sys_path(plugin_dirs: list[str]) -> None: # 保留: 标准库路径 + site-packages(含 SDK 和依赖) stdlib_paths = set() for key in ("stdlib", "platstdlib", "purelib", "platlib"): - path = sysconfig.get_path(key) - if path: + if path := sysconfig.get_path(key): stdlib_paths.add(os.path.normpath(path)) allowed = set()