161 lines
5.8 KiB
Python
161 lines
5.8 KiB
Python
"""Host-side EventDispatcher
|
||
|
||
负责:
|
||
1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry)
|
||
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
|
||
3. 支持阻塞(intercept_message)和非阻塞分发
|
||
4. 事件结果历史记录
|
||
"""
|
||
|
||
import asyncio
import copy
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple

from src.common.logger import get_logger
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
||
|
||
logger = get_logger("plugin_runtime.host.event_dispatcher")

# Dict payload exchanged with the Runner: the handler args going in and the
# response payload coming back.
_RpcPayload = Dict[str, Any]

# Async RPC callback that runs a single handler inside the Runner:
#     await invoke_fn(plugin_id, component_name, args) -> response payload dict
InvokeFn = Callable[[str, str, _RpcPayload], Awaitable[_RpcPayload]]
|
||
|
||
|
||
class EventResult:
    """Result of a single EventHandler invocation.

    Attributes:
        handler_name: Full name of the handler that produced this result.
        success: Whether the invocation succeeded (False when it raised or the
            handler reported failure).
        continue_processing: False asks the dispatcher to stop the chain of
            intercepting handlers.
        modified_message: Replacement message dict returned by the handler, if any.
        custom_result: Arbitrary handler-specific payload.
    """

    __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")

    def __init__(
        self,
        handler_name: str,
        success: bool = True,
        continue_processing: bool = True,
        modified_message: Optional[Dict[str, Any]] = None,
        custom_result: Any = None,
    ):
        self.handler_name = handler_name
        self.success = success
        self.continue_processing = continue_processing
        self.modified_message = modified_message
        self.custom_result = custom_result

    def __repr__(self) -> str:
        # __slots__ classes get no useful default repr; spell out every field so
        # recorded history and log lines are readable.
        return (
            f"{type(self).__name__}("
            f"handler_name={self.handler_name!r}, "
            f"success={self.success!r}, "
            f"continue_processing={self.continue_processing!r}, "
            f"modified_message={self.modified_message!r}, "
            f"custom_result={self.custom_result!r})"
        )
|
||
|
||
|
||
class EventDispatcher:
    """Host-side event dispatcher.

    The business layer calls ``dispatch_event()``. The dispatcher looks up the
    handlers for the event type through ``ComponentRegistry`` and invokes each
    one in the Runner through the supplied ``invoke_fn`` RPC callback.

    Handlers whose metadata sets ``intercept_message`` are awaited one after
    another and may stop processing or rewrite the message. All other handlers
    are scheduled as fire-and-forget tasks.
    """

    def __init__(self, registry: ComponentRegistry) -> None:
        self._registry: ComponentRegistry = registry
        # Recorded results per event type (filled only for history-enabled types).
        self._result_history: Dict[str, List[EventResult]] = {}
        # Event types whose handler results are recorded.
        self._history_enabled: Set[str] = set()
        # Strong references to fire-and-forget tasks so they are not
        # garbage-collected before they finish.
        self._background_tasks: Set[asyncio.Task] = set()

    def enable_history(self, event_type: str) -> None:
        """Start recording handler results for ``event_type``."""
        self._history_enabled.add(event_type)
        self._result_history.setdefault(event_type, [])

    def get_history(self, event_type: str) -> List[EventResult]:
        """Return the recorded results for ``event_type`` (empty list if none).

        For a recorded event type this is the live internal list, not a copy.
        """
        return self._result_history.get(event_type, [])

    def clear_history(self, event_type: str) -> None:
        """Drop the recorded results for ``event_type``; recording stays enabled."""
        if event_type in self._result_history:
            self._result_history[event_type] = []

    async def dispatch_event(
        self,
        event_type: str,
        invoke_fn: InvokeFn,
        message: Optional[Dict[str, Any]] = None,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Tuple[bool, Optional[Dict[str, Any]]]:
        """Dispatch an event to every handler registered for it.

        Args:
            event_type: Event type string.
            invoke_fn: Async callback ``(plugin_id, component_name, args) -> response_payload dict``.
            message: Serialized MaiMessages dict (optional).
            extra_args: Extra arguments merged into every handler's args. Keys
                here override ``event_type`` and ``message``.

        Returns:
            ``(should_continue, modified_message_dict)``.

            ``should_continue`` is False when an intercepting handler asked to
            stop; the non-blocking handlers are then not scheduled.

            ``modified_message_dict`` is the last rewrite accepted from an
            intercepting handler, or None if none rewrote the message.
        """
        # Ordering is whatever the registry returns. The module docs say
        # handlers are sorted by weight; presumably the registry does this
        # (verify in ComponentRegistry).
        handlers = self._registry.get_event_handlers(event_type)
        if not handlers:
            return True, None

        intercept_handlers, async_handlers = self._split_handlers(handlers)
        extra = extra_args or {}

        should_continue = True
        modified_message: Optional[Dict[str, Any]] = None

        # Blocking phase: each intercepting handler sees the latest rewrite.
        for handler in intercept_handlers:
            args = self._build_args(event_type, modified_message or message, extra)
            result = await self._invoke_handler(invoke_fn, handler, args, event_type)
            if not result.continue_processing:
                # A rewrite returned together with the stop signal is not applied.
                should_continue = False
                break
            # An empty/falsy modified_message is treated as "no rewrite".
            if result.modified_message:
                modified_message = result.modified_message

        if should_continue:
            final_message = modified_message or message
            for handler in async_handlers:
                # Tasks start only after this coroutine yields, possibly after the
                # caller has begun mutating the returned/original message. Give
                # every background handler its own deep snapshot so nested fields
                # are not shared with the caller or with other handlers.
                args = self._build_args(event_type, self._snapshot_message(final_message), extra)
                task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

        return should_continue, modified_message

    @staticmethod
    def _split_handlers(
        handlers: List[RegisteredComponent],
    ) -> Tuple[List[RegisteredComponent], List[RegisteredComponent]]:
        """Partition handlers into (intercepting, non-blocking), preserving order."""
        intercept: List[RegisteredComponent] = []
        background: List[RegisteredComponent] = []
        for handler in handlers:
            if handler.metadata.get("intercept_message", False):
                intercept.append(handler)
            else:
                background.append(handler)
        return intercept, background

    @staticmethod
    def _build_args(
        event_type: str,
        message: Optional[Dict[str, Any]],
        extra: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Build the RPC args for one handler; keys in ``extra`` override the defaults."""
        return {
            "event_type": event_type,
            "message": message,
            **extra,
        }

    @staticmethod
    def _snapshot_message(message: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Return an independent copy of ``message`` for a background handler."""
        if not isinstance(message, dict):
            return message
        try:
            return copy.deepcopy(message)
        except (TypeError, copy.Error):
            # Best effort: payloads holding non-copyable objects fall back to
            # the previous shallow-copy behaviour instead of failing dispatch.
            return message.copy()

    async def _invoke_handler(
        self,
        invoke_fn: InvokeFn,
        handler: RegisteredComponent,
        args: Dict[str, Any],
        event_type: str,
    ) -> EventResult:
        """Invoke one handler via ``invoke_fn`` and wrap its response.

        Any exception is logged and turned into a failed result that does not
        stop processing. This includes a response without a ``.get`` method.
        The result is appended to the history when recording is enabled for
        ``event_type``.
        """
        try:
            resp = await invoke_fn(handler.plugin_id, handler.name, args)
            result = EventResult(
                handler_name=handler.full_name,
                success=resp.get("success", True),
                continue_processing=resp.get("continue_processing", True),
                modified_message=resp.get("modified_message"),
                custom_result=resp.get("custom_result"),
            )
        except Exception as e:
            logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
            result = EventResult(
                handler_name=handler.full_name,
                success=False,
                continue_processing=True,
            )

        if event_type in self._history_enabled:
            self._result_history.setdefault(event_type, []).append(result)

        return result
|