refactor: 优化代码结构,简化条件判断和异常处理

This commit is contained in:
DrSmoothl
2026-03-06 12:00:00 +08:00
parent 2f21cd00bc
commit 1cd366bc09
8 changed files with 41 additions and 66 deletions

View File

@@ -886,7 +886,7 @@ class TestWorkflowExecutor:
await executor.execute( await executor.execute(
mock_invoke, message={"plain_text": "hi", "chat_type": "group"} mock_invoke, message={"plain_text": "hi", "chat_type": "group"}
) )
assert call_log == [] assert not call_log
# 匹配 filter —— hook 应被调用 # 匹配 filter —— hook 应被调用
await executor.execute( await executor.execute(

View File

@@ -42,8 +42,7 @@ class RegisteredComponent:
# 预编译命令正则(仅 command 类型) # 预编译命令正则(仅 command 类型)
self._compiled_pattern: re.Pattern | None = None self._compiled_pattern: re.Pattern | None = None
if component_type == "command": if component_type == "command":
pattern = metadata.get("command_pattern", "") if pattern := metadata.get("command_pattern", ""):
if pattern:
try: try:
self._compiled_pattern = re.compile(pattern) self._compiled_pattern = re.compile(pattern)
except re.error as e: except re.error as e:
@@ -120,8 +119,7 @@ class ComponentRegistry:
comps = self._by_plugin.pop(plugin_id, []) comps = self._by_plugin.pop(plugin_id, [])
for comp in comps: for comp in comps:
self._components.pop(comp.full_name, None) self._components.pop(comp.full_name, None)
type_dict = self._by_type.get(comp.component_type) if type_dict := self._by_type.get(comp.component_type):
if type_dict:
type_dict.pop(comp.full_name, None) type_dict.pop(comp.full_name, None)
return len(comps) return len(comps)
@@ -162,9 +160,7 @@ class ComponentRegistry:
) -> list[RegisteredComponent]: ) -> list[RegisteredComponent]:
"""按插件查询。""" """按插件查询。"""
comps = self._by_plugin.get(plugin_id, []) comps = self._by_plugin.get(plugin_id, [])
if enabled_only: return [c for c in comps if c.enabled] if enabled_only else list(comps)
return [c for c in comps if c.enabled]
return list(comps)
def find_command_by_text(self, text: str) -> RegisteredComponent | None: def find_command_by_text(self, text: str) -> RegisteredComponent | None:
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。""" """通过文本匹配命令正则,返回第一个匹配的 command 组件。"""

View File

@@ -93,7 +93,7 @@ class RPCServer:
self._running = False self._running = False
# 取消所有 pending 请求 # 取消所有 pending 请求
for req_id, future in self._pending_requests.items(): for future in self._pending_requests.values():
if not future.done(): if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
self._pending_requests.clear() self._pending_requests.clear()
@@ -162,16 +162,15 @@ class RPCServer:
# 等待响应 # 等待响应
timeout_sec = timeout_ms / 1000.0 timeout_sec = timeout_ms / 1000.0
response = await asyncio.wait_for(future, timeout=timeout_sec) return await asyncio.wait_for(future, timeout=timeout_sec)
return response
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None) self._pending_requests.pop(request_id, None)
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
except Exception as e: except Exception as e:
self._pending_requests.pop(request_id, None) self._pending_requests.pop(request_id, None)
if isinstance(e, RPCError): if isinstance(e, RPCError):
raise raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None: async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None:
"""向 Runner 发送单向事件(不等待响应)""" """向 Runner 发送单向事件(不等待响应)"""
@@ -338,8 +337,7 @@ class RPCServer:
async def _handle_event(self, envelope: Envelope) -> None: async def _handle_event(self, envelope: Envelope) -> None:
"""处理来自 Runner 的事件""" """处理来自 Runner 的事件"""
handler = self._method_handlers.get(envelope.method) if handler := self._method_handlers.get(envelope.method):
if handler:
try: try:
await handler(envelope) await handler(envelope)
except Exception as e: except Exception as e:

View File

@@ -203,19 +203,15 @@ class PluginSupervisor:
由主进程业务逻辑调用,通过 RPC 转发给 Runner。 由主进程业务逻辑调用,通过 RPC 转发给 Runner。
""" """
try: return await self._rpc_server.send_request(
response = await self._rpc_server.send_request( method=method,
method=method, plugin_id=plugin_id,
plugin_id=plugin_id, payload={
payload={ "component_name": component_name,
"component_name": component_name, "args": args or {},
"args": args or {}, },
}, timeout_ms=timeout_ms,
timeout_ms=timeout_ms, )
)
return response
except RPCError:
raise
async def reload_plugins(self, reason: str = "manual") -> None: async def reload_plugins(self, reason: str = "manual") -> None:
"""热重载所有插件(进程级 generation 切换) """热重载所有插件(进程级 generation 切换)

View File

@@ -17,7 +17,7 @@
- modification_log: 消息修改审计 - modification_log: 消息修改审计
""" """
from typing import Any, Callable, Awaitable, Optional from typing import Any, Awaitable, Callable
import asyncio import asyncio
import logging import logging
@@ -273,10 +273,9 @@ class WorkflowExecutor:
""" """
for key, expected in filter_cond.items(): for key, expected in filter_cond.items():
actual = message.get(key) actual = message.get(key)
if isinstance(expected, list): if (isinstance(expected, list) and actual not in expected) or (
if actual not in expected: not isinstance(expected, list) and actual != expected
return False ):
elif actual != expected:
return False return False
return True return True
@@ -376,12 +375,11 @@ class WorkflowExecutor:
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}") logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
try: try:
resp = await invoke_fn(matched.plugin_id, matched.name, { return await invoke_fn(matched.plugin_id, matched.name, {
"text": plain_text, "text": plain_text,
"message": message, "message": message,
"trace_id": ctx.trace_id, "trace_id": ctx.trace_id,
}) })
return resp
except Exception as e: except Exception as e:
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True) logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
ctx.errors.append(f"command:{matched.full_name}: {e}") ctx.errors.append(f"command:{matched.full_name}: {e}")
@@ -390,8 +388,4 @@ class WorkflowExecutor:
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]: def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
"""返回 new 中与 old 不同的 key 列表。""" """返回 new 中与 old 不同的 key 列表。"""
changed = [] return [k for k, v in new.items() if k not in old or old[k] != v]
for k in new:
if k not in old or old[k] != new[k]:
changed.append(k)
return changed

View File

@@ -46,8 +46,7 @@ class PluginMeta:
if isinstance(dep, str): if isinstance(dep, str):
result.append(dep.strip()) result.append(dep.strip())
elif isinstance(dep, dict): elif isinstance(dep, dict):
name = str(dep.get("name", "")).strip() if name := str(dep.get("name", "")).strip():
if name:
result.append(name) result.append(name)
return result return result
@@ -121,8 +120,7 @@ class PluginLoader:
for plugin_id in load_order: for plugin_id in load_order:
plugin_dir, manifest, plugin_path = candidates[plugin_id] plugin_dir, manifest, plugin_path = candidates[plugin_id]
try: try:
meta = self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path) if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path):
if meta:
self._loaded_plugins[meta.plugin_id] = meta self._loaded_plugins[meta.plugin_id] = meta
results.append(meta) results.append(meta)
except Exception as e: except Exception as e:

View File

@@ -11,12 +11,12 @@
from typing import Any, Callable, Awaitable from typing import Any, Callable, Awaitable
import asyncio import asyncio
import contextlib
import logging import logging
import uuid import uuid
from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
from src.plugin_runtime.protocol.envelope import ( from src.plugin_runtime.protocol.envelope import (
PROTOCOL_VERSION,
Envelope, Envelope,
HelloPayload, HelloPayload,
HelloResponsePayload, HelloResponsePayload,
@@ -129,10 +129,8 @@ class RPCClient:
self._running = False self._running = False
if self._recv_task: if self._recv_task:
self._recv_task.cancel() self._recv_task.cancel()
try: with contextlib.suppress(asyncio.CancelledError):
await self._recv_task await self._recv_task
except asyncio.CancelledError:
pass
self._recv_task = None self._recv_task = None
# 取消所有 pending 请求 # 取消所有 pending 请求
@@ -176,16 +174,15 @@ class RPCClient:
await self._connection.send_frame(data) await self._connection.send_frame(data)
timeout_sec = timeout_ms / 1000.0 timeout_sec = timeout_ms / 1000.0
response = await asyncio.wait_for(future, timeout=timeout_sec) return await asyncio.wait_for(future, timeout=timeout_sec)
return response
except asyncio.TimeoutError: except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None) self._pending_requests.pop(request_id, None)
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
except Exception as e: except Exception as e:
self._pending_requests.pop(request_id, None) self._pending_requests.pop(request_id, None)
if isinstance(e, RPCError): if isinstance(e, RPCError):
raise raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
# ─── 内部方法 ────────────────────────────────────────────── # ─── 内部方法 ──────────────────────────────────────────────
@@ -249,8 +246,7 @@ class RPCClient:
async def _handle_event(self, envelope: Envelope) -> None: async def _handle_event(self, envelope: Envelope) -> None:
"""处理来自 Host 的事件""" """处理来自 Host 的事件"""
handler = self._method_handlers.get(envelope.method) if handler := self._method_handlers.get(envelope.method):
if handler:
try: try:
await handler(envelope) await handler(envelope)
except Exception as e: except Exception as e:

View File

@@ -12,6 +12,7 @@
from typing import Any from typing import Any
import asyncio import asyncio
import contextlib
import logging import logging
import os import os
import signal import signal
@@ -92,11 +93,9 @@ class PluginRunner:
await self._register_plugin(meta) await self._register_plugin(meta)
# 5. 等待直到收到关停信号 # 5. 等待直到收到关停信号
try: with contextlib.suppress(asyncio.CancelledError):
while not self._shutting_down: while not self._shutting_down:
await asyncio.sleep(1.0) await asyncio.sleep(1.0)
except asyncio.CancelledError:
pass
# 6. 断开连接 # 6. 断开连接
await self._rpc_client.disconnect() await self._rpc_client.disconnect()
@@ -122,13 +121,15 @@ class PluginRunner:
# 从插件实例获取组件声明SDK 插件须实现 get_components 方法) # 从插件实例获取组件声明SDK 插件须实现 get_components 方法)
if hasattr(instance, "get_components"): if hasattr(instance, "get_components"):
for comp_info in instance.get_components(): components.extend(
components.append(ComponentDeclaration( ComponentDeclaration(
name=comp_info.get("name", ""), name=comp_info.get("name", ""),
component_type=comp_info.get("type", ""), component_type=comp_info.get("type", ""),
plugin_id=meta.plugin_id, plugin_id=meta.plugin_id,
metadata=comp_info.get("metadata", {}), metadata=comp_info.get("metadata", {}),
)) )
for comp_info in instance.get_components()
)
reg_payload = RegisterComponentsPayload( reg_payload = RegisterComponentsPayload(
plugin_id=meta.plugin_id, plugin_id=meta.plugin_id,
@@ -219,10 +220,7 @@ class PluginRunner:
raw = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args) raw = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
# 规范化返回值 # 规范化返回值
if raw is None: if isinstance(raw, str):
result = {"hook_result": "continue"}
elif isinstance(raw, str):
# 允许直接返回 hook_result 字符串
result = {"hook_result": raw} result = {"hook_result": raw}
elif isinstance(raw, dict): elif isinstance(raw, dict):
result = raw result = raw
@@ -298,8 +296,7 @@ def _isolate_sys_path(plugin_dirs: list[str]) -> None:
# 保留: 标准库路径 + site-packages含 SDK 和依赖) # 保留: 标准库路径 + site-packages含 SDK 和依赖)
stdlib_paths = set() stdlib_paths = set()
for key in ("stdlib", "platstdlib", "purelib", "platlib"): for key in ("stdlib", "platstdlib", "purelib", "platlib"):
path = sysconfig.get_path(key) if path := sysconfig.get_path(key):
if path:
stdlib_paths.add(os.path.normpath(path)) stdlib_paths.add(os.path.normpath(path))
allowed = set() allowed = set()