refactor: 优化代码结构,简化条件判断和异常处理
This commit is contained in:
@@ -886,7 +886,7 @@ class TestWorkflowExecutor:
|
|||||||
await executor.execute(
|
await executor.execute(
|
||||||
mock_invoke, message={"plain_text": "hi", "chat_type": "group"}
|
mock_invoke, message={"plain_text": "hi", "chat_type": "group"}
|
||||||
)
|
)
|
||||||
assert call_log == []
|
assert not call_log
|
||||||
|
|
||||||
# 匹配 filter —— hook 应被调用
|
# 匹配 filter —— hook 应被调用
|
||||||
await executor.execute(
|
await executor.execute(
|
||||||
|
|||||||
@@ -42,8 +42,7 @@ class RegisteredComponent:
|
|||||||
# 预编译命令正则(仅 command 类型)
|
# 预编译命令正则(仅 command 类型)
|
||||||
self._compiled_pattern: re.Pattern | None = None
|
self._compiled_pattern: re.Pattern | None = None
|
||||||
if component_type == "command":
|
if component_type == "command":
|
||||||
pattern = metadata.get("command_pattern", "")
|
if pattern := metadata.get("command_pattern", ""):
|
||||||
if pattern:
|
|
||||||
try:
|
try:
|
||||||
self._compiled_pattern = re.compile(pattern)
|
self._compiled_pattern = re.compile(pattern)
|
||||||
except re.error as e:
|
except re.error as e:
|
||||||
@@ -120,8 +119,7 @@ class ComponentRegistry:
|
|||||||
comps = self._by_plugin.pop(plugin_id, [])
|
comps = self._by_plugin.pop(plugin_id, [])
|
||||||
for comp in comps:
|
for comp in comps:
|
||||||
self._components.pop(comp.full_name, None)
|
self._components.pop(comp.full_name, None)
|
||||||
type_dict = self._by_type.get(comp.component_type)
|
if type_dict := self._by_type.get(comp.component_type):
|
||||||
if type_dict:
|
|
||||||
type_dict.pop(comp.full_name, None)
|
type_dict.pop(comp.full_name, None)
|
||||||
return len(comps)
|
return len(comps)
|
||||||
|
|
||||||
@@ -162,9 +160,7 @@ class ComponentRegistry:
|
|||||||
) -> list[RegisteredComponent]:
|
) -> list[RegisteredComponent]:
|
||||||
"""按插件查询。"""
|
"""按插件查询。"""
|
||||||
comps = self._by_plugin.get(plugin_id, [])
|
comps = self._by_plugin.get(plugin_id, [])
|
||||||
if enabled_only:
|
return [c for c in comps if c.enabled] if enabled_only else list(comps)
|
||||||
return [c for c in comps if c.enabled]
|
|
||||||
return list(comps)
|
|
||||||
|
|
||||||
def find_command_by_text(self, text: str) -> RegisteredComponent | None:
|
def find_command_by_text(self, text: str) -> RegisteredComponent | None:
|
||||||
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。"""
|
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。"""
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ class RPCServer:
|
|||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
# 取消所有 pending 请求
|
# 取消所有 pending 请求
|
||||||
for req_id, future in self._pending_requests.items():
|
for future in self._pending_requests.values():
|
||||||
if not future.done():
|
if not future.done():
|
||||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||||
self._pending_requests.clear()
|
self._pending_requests.clear()
|
||||||
@@ -162,16 +162,15 @@ class RPCServer:
|
|||||||
|
|
||||||
# 等待响应
|
# 等待响应
|
||||||
timeout_sec = timeout_ms / 1000.0
|
timeout_sec = timeout_ms / 1000.0
|
||||||
response = await asyncio.wait_for(future, timeout=timeout_sec)
|
return await asyncio.wait_for(future, timeout=timeout_sec)
|
||||||
return response
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self._pending_requests.pop(request_id, None)
|
self._pending_requests.pop(request_id, None)
|
||||||
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)")
|
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._pending_requests.pop(request_id, None)
|
self._pending_requests.pop(request_id, None)
|
||||||
if isinstance(e, RPCError):
|
if isinstance(e, RPCError):
|
||||||
raise
|
raise
|
||||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e))
|
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||||
|
|
||||||
async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None:
|
async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None:
|
||||||
"""向 Runner 发送单向事件(不等待响应)"""
|
"""向 Runner 发送单向事件(不等待响应)"""
|
||||||
@@ -338,8 +337,7 @@ class RPCServer:
|
|||||||
|
|
||||||
async def _handle_event(self, envelope: Envelope) -> None:
|
async def _handle_event(self, envelope: Envelope) -> None:
|
||||||
"""处理来自 Runner 的事件"""
|
"""处理来自 Runner 的事件"""
|
||||||
handler = self._method_handlers.get(envelope.method)
|
if handler := self._method_handlers.get(envelope.method):
|
||||||
if handler:
|
|
||||||
try:
|
try:
|
||||||
await handler(envelope)
|
await handler(envelope)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -203,19 +203,15 @@ class PluginSupervisor:
|
|||||||
|
|
||||||
由主进程业务逻辑调用,通过 RPC 转发给 Runner。
|
由主进程业务逻辑调用,通过 RPC 转发给 Runner。
|
||||||
"""
|
"""
|
||||||
try:
|
return await self._rpc_server.send_request(
|
||||||
response = await self._rpc_server.send_request(
|
method=method,
|
||||||
method=method,
|
plugin_id=plugin_id,
|
||||||
plugin_id=plugin_id,
|
payload={
|
||||||
payload={
|
"component_name": component_name,
|
||||||
"component_name": component_name,
|
"args": args or {},
|
||||||
"args": args or {},
|
},
|
||||||
},
|
timeout_ms=timeout_ms,
|
||||||
timeout_ms=timeout_ms,
|
)
|
||||||
)
|
|
||||||
return response
|
|
||||||
except RPCError:
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def reload_plugins(self, reason: str = "manual") -> None:
|
async def reload_plugins(self, reason: str = "manual") -> None:
|
||||||
"""热重载所有插件(进程级 generation 切换)
|
"""热重载所有插件(进程级 generation 切换)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
- modification_log: 消息修改审计
|
- modification_log: 消息修改审计
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Callable, Awaitable, Optional
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
@@ -273,10 +273,9 @@ class WorkflowExecutor:
|
|||||||
"""
|
"""
|
||||||
for key, expected in filter_cond.items():
|
for key, expected in filter_cond.items():
|
||||||
actual = message.get(key)
|
actual = message.get(key)
|
||||||
if isinstance(expected, list):
|
if (isinstance(expected, list) and actual not in expected) or (
|
||||||
if actual not in expected:
|
not isinstance(expected, list) and actual != expected
|
||||||
return False
|
):
|
||||||
elif actual != expected:
|
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -376,12 +375,11 @@ class WorkflowExecutor:
|
|||||||
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await invoke_fn(matched.plugin_id, matched.name, {
|
return await invoke_fn(matched.plugin_id, matched.name, {
|
||||||
"text": plain_text,
|
"text": plain_text,
|
||||||
"message": message,
|
"message": message,
|
||||||
"trace_id": ctx.trace_id,
|
"trace_id": ctx.trace_id,
|
||||||
})
|
})
|
||||||
return resp
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
||||||
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
||||||
@@ -390,8 +388,4 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
|
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
|
||||||
"""返回 new 中与 old 不同的 key 列表。"""
|
"""返回 new 中与 old 不同的 key 列表。"""
|
||||||
changed = []
|
return [k for k, v in new.items() if k not in old or old[k] != v]
|
||||||
for k in new:
|
|
||||||
if k not in old or old[k] != new[k]:
|
|
||||||
changed.append(k)
|
|
||||||
return changed
|
|
||||||
|
|||||||
@@ -46,8 +46,7 @@ class PluginMeta:
|
|||||||
if isinstance(dep, str):
|
if isinstance(dep, str):
|
||||||
result.append(dep.strip())
|
result.append(dep.strip())
|
||||||
elif isinstance(dep, dict):
|
elif isinstance(dep, dict):
|
||||||
name = str(dep.get("name", "")).strip()
|
if name := str(dep.get("name", "")).strip():
|
||||||
if name:
|
|
||||||
result.append(name)
|
result.append(name)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -121,8 +120,7 @@ class PluginLoader:
|
|||||||
for plugin_id in load_order:
|
for plugin_id in load_order:
|
||||||
plugin_dir, manifest, plugin_path = candidates[plugin_id]
|
plugin_dir, manifest, plugin_path = candidates[plugin_id]
|
||||||
try:
|
try:
|
||||||
meta = self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path)
|
if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path):
|
||||||
if meta:
|
|
||||||
self._loaded_plugins[meta.plugin_id] = meta
|
self._loaded_plugins[meta.plugin_id] = meta
|
||||||
results.append(meta)
|
results.append(meta)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -11,12 +11,12 @@
|
|||||||
from typing import Any, Callable, Awaitable
|
from typing import Any, Callable, Awaitable
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
|
from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
|
||||||
from src.plugin_runtime.protocol.envelope import (
|
from src.plugin_runtime.protocol.envelope import (
|
||||||
PROTOCOL_VERSION,
|
|
||||||
Envelope,
|
Envelope,
|
||||||
HelloPayload,
|
HelloPayload,
|
||||||
HelloResponsePayload,
|
HelloResponsePayload,
|
||||||
@@ -129,10 +129,8 @@ class RPCClient:
|
|||||||
self._running = False
|
self._running = False
|
||||||
if self._recv_task:
|
if self._recv_task:
|
||||||
self._recv_task.cancel()
|
self._recv_task.cancel()
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await self._recv_task
|
await self._recv_task
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._recv_task = None
|
self._recv_task = None
|
||||||
|
|
||||||
# 取消所有 pending 请求
|
# 取消所有 pending 请求
|
||||||
@@ -176,16 +174,15 @@ class RPCClient:
|
|||||||
await self._connection.send_frame(data)
|
await self._connection.send_frame(data)
|
||||||
|
|
||||||
timeout_sec = timeout_ms / 1000.0
|
timeout_sec = timeout_ms / 1000.0
|
||||||
response = await asyncio.wait_for(future, timeout=timeout_sec)
|
return await asyncio.wait_for(future, timeout=timeout_sec)
|
||||||
return response
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self._pending_requests.pop(request_id, None)
|
self._pending_requests.pop(request_id, None)
|
||||||
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)")
|
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._pending_requests.pop(request_id, None)
|
self._pending_requests.pop(request_id, None)
|
||||||
if isinstance(e, RPCError):
|
if isinstance(e, RPCError):
|
||||||
raise
|
raise
|
||||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e))
|
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||||
|
|
||||||
# ─── 内部方法 ──────────────────────────────────────────────
|
# ─── 内部方法 ──────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -249,8 +246,7 @@ class RPCClient:
|
|||||||
|
|
||||||
async def _handle_event(self, envelope: Envelope) -> None:
|
async def _handle_event(self, envelope: Envelope) -> None:
|
||||||
"""处理来自 Host 的事件"""
|
"""处理来自 Host 的事件"""
|
||||||
handler = self._method_handlers.get(envelope.method)
|
if handler := self._method_handlers.get(envelope.method):
|
||||||
if handler:
|
|
||||||
try:
|
try:
|
||||||
await handler(envelope)
|
await handler(envelope)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
@@ -92,11 +93,9 @@ class PluginRunner:
|
|||||||
await self._register_plugin(meta)
|
await self._register_plugin(meta)
|
||||||
|
|
||||||
# 5. 等待直到收到关停信号
|
# 5. 等待直到收到关停信号
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
while not self._shutting_down:
|
while not self._shutting_down:
|
||||||
await asyncio.sleep(1.0)
|
await asyncio.sleep(1.0)
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 6. 断开连接
|
# 6. 断开连接
|
||||||
await self._rpc_client.disconnect()
|
await self._rpc_client.disconnect()
|
||||||
@@ -122,13 +121,15 @@ class PluginRunner:
|
|||||||
|
|
||||||
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
|
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
|
||||||
if hasattr(instance, "get_components"):
|
if hasattr(instance, "get_components"):
|
||||||
for comp_info in instance.get_components():
|
components.extend(
|
||||||
components.append(ComponentDeclaration(
|
ComponentDeclaration(
|
||||||
name=comp_info.get("name", ""),
|
name=comp_info.get("name", ""),
|
||||||
component_type=comp_info.get("type", ""),
|
component_type=comp_info.get("type", ""),
|
||||||
plugin_id=meta.plugin_id,
|
plugin_id=meta.plugin_id,
|
||||||
metadata=comp_info.get("metadata", {}),
|
metadata=comp_info.get("metadata", {}),
|
||||||
))
|
)
|
||||||
|
for comp_info in instance.get_components()
|
||||||
|
)
|
||||||
|
|
||||||
reg_payload = RegisterComponentsPayload(
|
reg_payload = RegisterComponentsPayload(
|
||||||
plugin_id=meta.plugin_id,
|
plugin_id=meta.plugin_id,
|
||||||
@@ -219,10 +220,7 @@ class PluginRunner:
|
|||||||
raw = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
|
raw = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
|
||||||
|
|
||||||
# 规范化返回值
|
# 规范化返回值
|
||||||
if raw is None:
|
if isinstance(raw, str):
|
||||||
result = {"hook_result": "continue"}
|
|
||||||
elif isinstance(raw, str):
|
|
||||||
# 允许直接返回 hook_result 字符串
|
|
||||||
result = {"hook_result": raw}
|
result = {"hook_result": raw}
|
||||||
elif isinstance(raw, dict):
|
elif isinstance(raw, dict):
|
||||||
result = raw
|
result = raw
|
||||||
@@ -298,8 +296,7 @@ def _isolate_sys_path(plugin_dirs: list[str]) -> None:
|
|||||||
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
|
# 保留: 标准库路径 + site-packages(含 SDK 和依赖)
|
||||||
stdlib_paths = set()
|
stdlib_paths = set()
|
||||||
for key in ("stdlib", "platstdlib", "purelib", "platlib"):
|
for key in ("stdlib", "platstdlib", "purelib", "platlib"):
|
||||||
path = sysconfig.get_path(key)
|
if path := sysconfig.get_path(key):
|
||||||
if path:
|
|
||||||
stdlib_paths.add(os.path.normpath(path))
|
stdlib_paths.add(os.path.normpath(path))
|
||||||
|
|
||||||
allowed = set()
|
allowed = set()
|
||||||
|
|||||||
Reference in New Issue
Block a user