Files
mai-bot/src/plugin_runtime/host/event_dispatcher.py

161 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Host-side EventDispatcher
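Responsibilities:
1. Look up the event_handlers registered for an event type (via ComponentRegistry).
2. Sort them by weight and invoke the handlers in the Runner one by one over RPC.
3. Support blocking (intercept_message) and non-blocking dispatch.
4. Keep a history of event results.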
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
import asyncio
from src.common.logger import get_logger
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
# Module-level logger used by the dispatcher to report handler failures.
logger = get_logger("plugin_runtime.host.event_dispatcher")
# Type of invoke_fn: async (plugin_id, component_name, args) -> response_payload dict
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
class EventResult:
    """Outcome of a single EventHandler invocation.

    Attributes:
        handler_name: Name identifying the handler that produced this result.
        success: Whether the handler reported success.
        continue_processing: Whether the remaining handlers should still run.
        modified_message: Replacement message dict returned by the handler,
            or ``None`` when the handler did not modify the message.
        custom_result: Arbitrary handler-specific payload.
    """

    __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")

    def __init__(
        self,
        handler_name: str,
        success: bool = True,
        continue_processing: bool = True,
        modified_message: Optional[Dict[str, Any]] = None,
        custom_result: Any = None,
    ):
        self.handler_name = handler_name
        self.success = success
        self.continue_processing = continue_processing
        self.modified_message = modified_message
        self.custom_result = custom_result

    def __repr__(self) -> str:
        """Return a readable ``EventResult(field=value, ...)`` string for logs and debugging."""
        # Driven by __slots__ so the output stays in sync if fields are added later.
        fields = ", ".join(f"{name}={getattr(self, name)!r}" for name in self.__slots__)
        return f"{type(self).__name__}({fields})"
class EventDispatcher:
    """Host-side event dispatcher.

    The business layer calls ``dispatch_event()``. Handlers are looked up
    through the ComponentRegistry and executed by awaiting the caller-supplied
    ``invoke_fn`` callback (described in this file as an RPC into the Runner).
    """

    def __init__(self, registry: ComponentRegistry) -> None:
        """Create a dispatcher that resolves handlers through ``registry``."""
        self._registry: ComponentRegistry = registry
        # Recorded handler results, keyed by event type.
        self._result_history: Dict[str, List[EventResult]] = {}
        # Event types whose handler results are recorded.
        self._history_enabled: Set[str] = set()
        # Strong references to fire-and-forget tasks so they are not
        # garbage-collected before they finish.
        self._background_tasks: Set[asyncio.Task] = set()

    def enable_history(self, event_type: str) -> None:
        """Start recording handler results for ``event_type``."""
        self._history_enabled.add(event_type)
        self._result_history.setdefault(event_type, [])

    def get_history(self, event_type: str) -> List[EventResult]:
        """Return a snapshot of the results recorded for ``event_type``.

        A new list is returned on every call (empty when nothing has been
        recorded), so callers can neither mutate the dispatcher's internal
        state nor end up holding a list that goes stale once
        ``clear_history()`` rebinds it.
        """
        return list(self._result_history.get(event_type, ()))

    def clear_history(self, event_type: str) -> None:
        """Drop all recorded results for ``event_type``; recording stays enabled."""
        if event_type in self._result_history:
            self._result_history[event_type] = []

    async def dispatch_event(
        self,
        event_type: str,
        invoke_fn: InvokeFn,
        message: Optional[Dict[str, Any]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Tuple[bool, Optional[Dict[str, Any]]]:
        """Dispatch an event to every handler registered for it.

        Handlers whose metadata sets ``intercept_message`` are awaited one
        after another, in the order returned by the registry; each may replace
        the message or stop further processing. The remaining handlers are then
        started as background tasks and are not awaited.

        Args:
            event_type: Event type string.
            invoke_fn: Async callback with signature
                ``(plugin_id, component_name, args) -> response_payload dict``.
            message: Optional message dict (the serialized MaiMessages object,
                according to the original documentation).
            extra_args: Extra arguments merged into each handler's args. They
                are merged last, so colliding keys override ``event_type`` and
                ``message``.

        Returns:
            ``(should_continue, modified_message)``. ``modified_message`` is
            the last non-empty replacement produced by a blocking handler, or
            ``None`` when no handler modified the message.
        """
        handlers = self._registry.get_event_handlers(event_type)
        if not handlers:
            return True, None

        should_continue = True
        modified_message: Optional[Dict[str, Any]] = None

        # Split handlers into blocking (intercepting) and fire-and-forget ones.
        intercept_handlers: List[RegisteredComponent] = []
        async_handlers: List[RegisteredComponent] = []
        for handler in handlers:
            if handler.metadata.get("intercept_message", False):
                intercept_handlers.append(handler)
            else:
                async_handlers.append(handler)

        # Blocking phase: each handler sees the latest modified message.
        for handler in intercept_handlers:
            args = {
                "event_type": event_type,
                "message": modified_message or message,
                **(extra_args or {}),
            }
            result = await self._invoke_handler(invoke_fn, handler, args, event_type)
            if result and not result.continue_processing:
                # The stopping handler's own modified_message is not applied.
                should_continue = False
                break
            if result and result.modified_message:
                modified_message = result.modified_message

        if should_continue:
            final_message = modified_message or message
            for handler in async_handlers:
                # Shallow copy so background handlers do not share the top-level dict.
                async_message = final_message.copy() if isinstance(final_message, dict) else final_message
                args = {
                    "event_type": event_type,
                    "message": async_message,
                    **(extra_args or {}),
                }
                # Non-blocking: keep an instance-level strong reference so the
                # task is not garbage-collected; it is dropped on completion.
                task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

        return should_continue, modified_message

    async def _invoke_handler(
        self,
        invoke_fn: InvokeFn,
        handler: RegisteredComponent,
        args: Dict[str, Any],
        event_type: str,
    ) -> Optional[EventResult]:
        """Invoke a single handler and collect its result.

        Any ``Exception`` raised by ``invoke_fn`` or while reading its response
        is logged and converted into a failed result that still allows
        processing to continue. The result is appended to the history when
        recording is enabled for ``event_type``.
        """
        try:
            resp = await invoke_fn(handler.plugin_id, handler.name, args)
            result = EventResult(
                handler_name=handler.full_name,
                success=resp.get("success", True),
                continue_processing=resp.get("continue_processing", True),
                modified_message=resp.get("modified_message"),
                custom_result=resp.get("custom_result"),
            )
        except Exception as e:
            logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
            result = EventResult(
                handler_name=handler.full_name,
                success=False,
                continue_processing=True,
            )

        if event_type in self._history_enabled:
            self._result_history.setdefault(event_type, []).append(result)
        return result