refactor: event_dispatcher

This commit is contained in:
UnCLAS-Prommer
2026-03-18 15:29:08 +08:00
committed by DrSmoothl
parent 14a0c21cbf
commit 32519c688b
2 changed files with 65 additions and 56 deletions

View File

@@ -101,6 +101,7 @@ class EventHandlerEntry(ComponentEntry):
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.event_type: str = metadata.get("event_type", "")
self.weight: int = metadata.get("weight", 0)
self.intercept_message: bool = metadata.get("intercept_message", False)
super().__init__(name, component_type, plugin_id, metadata)
@@ -356,7 +357,7 @@ class ComponentRegistry:
self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[EventHandlerEntry]:
"""查询指定事件类型的事件处理器组件。
Args:
event_type (str): 事件类型
enabled_only (bool): 是否仅返回启用的组件
@@ -400,7 +401,7 @@ class ComponentRegistry:
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
"""查询所有工具组件。
Args:
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
@@ -418,7 +419,7 @@ class ComponentRegistry:
# ====== 统计信息 ======
def get_stats(self) -> StatusDict:
"""获取注册统计。
Returns:
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
"""

View File

@@ -4,40 +4,38 @@
1. 按事件类型查询已注册的 event_handler通过 ComponentRegistry
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
3. 支持阻塞intercept_message和非阻塞分发
4. 事件结果历史记录
4. 事件结果历史记录(有上限)
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import asyncio
from src.common.logger import get_logger
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
if TYPE_CHECKING:
from .supervisor import PluginRunnerSupervisor
from .component_registry import ComponentRegistry, EventHandlerEntry
logger = get_logger("plugin_runtime.host.event_dispatcher")
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
# 每个事件类型的最大历史记录数量,防止内存无限增长
_MAX_HISTORY_LENGTH = 100
@dataclass
class EventResult:
"""单个 EventHandler 的执行结果"""
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
def __init__(
self,
handler_name: str,
success: bool = True,
continue_processing: bool = True,
modified_message: Optional[Dict[str, Any]] = None,
custom_result: Any = None,
):
self.handler_name = handler_name
self.success = success
self.continue_processing = continue_processing
self.modified_message = modified_message
self.custom_result = custom_result
handler_name: str
success: bool = field(default=True)
continue_processing: bool = field(default=True)
modified_message: Optional[Dict[str, Any]] = field(default=None)
custom_result: Any = field(default=None)
class EventDispatcher:
@@ -48,8 +46,8 @@ class EventDispatcher:
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
def __init__(self, registry: ComponentRegistry) -> None:
self._registry: ComponentRegistry = registry
def __init__(self, component_registry: "ComponentRegistry") -> None:
self._component_registry: "ComponentRegistry" = component_registry
self._result_history: Dict[str, List[EventResult]] = {}
self._history_enabled: Set[str] = set()
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
@@ -59,6 +57,10 @@ class EventDispatcher:
self._history_enabled.add(event_type)
self._result_history.setdefault(event_type, [])
def disable_history(self, event_type: str) -> None:
self._history_enabled.discard(event_type)
self._result_history.pop(event_type, None)
def get_history(self, event_type: str) -> List[EventResult]:
return self._result_history.get(event_type, [])
@@ -69,44 +71,46 @@ class EventDispatcher:
async def dispatch_event(
self,
event_type: str,
invoke_fn: InvokeFn,
message: Optional[Dict[str, Any]] = None,
supervisor: "PluginRunnerSupervisor",
message_dict: Optional[Dict[str, Any]] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""分发事件到所有对应 handler。
"""分发事件到所有对应 handler 的便捷方法
内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑,
无需调用方手动构造 invoke_fn 闭包。
Args:
event_type: 事件类型字符串
invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
supervisor: PluginSupervisor 实例,用于调用 invoke_plugin
message: MaiMessages 序列化后的 dict可选
extra_args: 额外参数
Returns:
(should_continue, modified_message_dict)
(should_continue, modified_message_dict) (bool, Dict[str, Any] | None): (是否继续后续执行, 可选的修改后的消息字典)
"""
handlers = self._registry.get_event_handlers(event_type)
if not handlers:
handler_entries = self._component_registry.get_event_handlers(event_type)
if not handler_entries:
return True, None
should_continue = True
modified_message: Optional[Dict[str, Any]] = None
intercept_handlers: List[RegisteredComponent] = []
async_handlers: List[RegisteredComponent] = []
modified_message: Optional[Dict[str, Any]] = message_dict
intercept_handlers: List["EventHandlerEntry"] = []
non_blocking_handlers: List["EventHandlerEntry"] = []
for handler in handlers:
if handler.metadata.get("intercept_message", False):
intercept_handlers.append(handler)
for entry in handler_entries:
if entry.intercept_message:
intercept_handlers.append(entry)
else:
async_handlers.append(handler)
non_blocking_handlers.append(entry)
for handler in intercept_handlers:
for entry in intercept_handlers:
args = {
"event_type": event_type,
"message": modified_message or message,
"message": modified_message,
**(extra_args or {}),
}
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
result = await self._invoke_handler(supervisor, entry, args, event_type)
if result and not result.continue_processing:
should_continue = False
break
@@ -114,16 +118,16 @@ class EventDispatcher:
modified_message = result.modified_message
if should_continue:
final_message = modified_message or message
for handler in async_handlers:
async_message = final_message.copy() if isinstance(final_message, dict) else final_message
final_message = modified_message
for entry in non_blocking_handlers:
async_message = final_message.copy() if final_message else final_message
args = {
"event_type": event_type,
"message": async_message,
**(extra_args or {}),
}
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
@@ -131,30 +135,34 @@ class EventDispatcher:
async def _invoke_handler(
self,
invoke_fn: InvokeFn,
handler: RegisteredComponent,
supervisor: "PluginRunnerSupervisor",
handler_entry: "EventHandlerEntry",
args: Dict[str, Any],
event_type: str,
) -> Optional[EventResult]:
"""调用单个 handler 并收集结果。"""
try:
resp = await invoke_fn(handler.plugin_id, handler.name, args)
resp_envelope = await supervisor.invoke_plugin(
"plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args
)
resp = resp_envelope.payload
result = EventResult(
handler_name=handler.full_name,
handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_message=resp.get("modified_message"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
result = EventResult(
handler_name=handler.full_name,
success=False,
continue_processing=True,
)
logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True)
result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
if event_type in self._history_enabled:
self._result_history.setdefault(event_type, []).append(result)
history_list = self._result_history.setdefault(event_type, [])
history_list.append(result)
# 自动清理超出限制的旧记录,防止内存无限增长
if len(history_list) > _MAX_HISTORY_LENGTH:
# 保留最新的 _MAX_HISTORY_LENGTH 条记录
self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:]
return result