feat: 增强插件能力检查,支持 generation 校验并添加清理功能
This commit is contained in:
@@ -8,14 +8,13 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
import contextlib
|
||||
from dataclasses import fields
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import EventType, MaiMessages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
|
||||
logger = get_logger("event_bus")
|
||||
|
||||
# Handler 签名:接收 MaiMessages,返回 (continue, modified_message)
|
||||
@@ -127,8 +126,7 @@ class EventBus:
|
||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||
"""取消某个 handler 的所有运行中任务"""
|
||||
tasks = self._running_tasks.pop(handler_name, [])
|
||||
remaining = [t for t in tasks if not t.done()]
|
||||
if remaining:
|
||||
if remaining := [t for t in tasks if not t.done()]:
|
||||
for t in remaining:
|
||||
t.cancel()
|
||||
await asyncio.gather(*remaining, return_exceptions=True)
|
||||
@@ -156,17 +154,14 @@ class EventBus:
|
||||
try:
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
if exc := task.exception():
|
||||
logger.error(f"handler {handler_name} 异步任务异常: {exc}")
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
task_list = self._running_tasks.get(handler_name, [])
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
task_list.remove(task)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def _bridge_to_ipc_runtime(
|
||||
self,
|
||||
@@ -188,17 +183,29 @@ class EventBus:
|
||||
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
|
||||
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
|
||||
|
||||
new_continue, _ = await prm.bridge_event(
|
||||
new_continue, modified_dict = await prm.bridge_event(
|
||||
event_type_value=event_value,
|
||||
message_dict=message_dict,
|
||||
)
|
||||
if not new_continue:
|
||||
continue_flag = False
|
||||
if modified_dict is not None and message is not None:
|
||||
message = self._apply_ipc_message_update(message, modified_dict)
|
||||
except Exception as e:
|
||||
logger.warning(f"桥接事件到 IPC 运行时失败: {e}")
|
||||
|
||||
return continue_flag, message
|
||||
|
||||
@staticmethod
|
||||
def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages:
|
||||
"""将 IPC 返回的消息字典回写到当前 MaiMessages。"""
|
||||
updated_message = message.deepcopy()
|
||||
valid_fields = {field.name for field in fields(MaiMessages)}
|
||||
for key, value in modified_dict.items():
|
||||
if key in valid_fields:
|
||||
setattr(updated_message, key, value)
|
||||
return updated_message
|
||||
|
||||
|
||||
class _HandlerEntry:
|
||||
"""内部 handler 条目"""
|
||||
|
||||
Reference in New Issue
Block a user