重构事件总线和插件运行时,优化消息处理逻辑,新增 IPC 传输字典转换功能,改进组件管理协议

This commit is contained in:
DrSmoothl
2026-03-14 01:39:59 +08:00
parent 84212e8e95
commit 2c330e3902
7 changed files with 197 additions and 56 deletions

View File

@@ -9,7 +9,6 @@
import asyncio
import contextlib
from dataclasses import fields
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
@@ -101,20 +100,30 @@ class EventBus:
continue_flag = True
current_message = message.deepcopy() if message else None
intercept_handlers: List[_HandlerEntry] = []
async_handlers: List[_HandlerEntry] = []
for entry in handlers:
if entry.intercept:
try:
should_continue, modified = await entry.handler(current_message)
if modified is not None:
current_message = modified
if not should_continue:
continue_flag = False
break
except Exception as e:
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
intercept_handlers.append(entry)
else:
self._fire_and_forget(entry, event_type, current_message)
async_handlers.append(entry)
for entry in intercept_handlers:
try:
should_continue, modified = await entry.handler(current_message)
if modified is not None:
current_message = modified
if not should_continue:
continue_flag = False
break
except Exception as e:
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
if continue_flag:
for entry in async_handlers:
async_message = current_message.deepcopy() if current_message else None
self._fire_and_forget(entry, event_type, async_message)
# 桥接到 IPC 插件运行时
continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message)
@@ -138,7 +147,7 @@ class EventBus:
event_type: EventType | str,
message: Optional[MaiMessages],
) -> None:
"""创建异步任务执行非拦截型 handler"""
"""创建异步任务执行非拦截型 handler"""
try:
task = asyncio.create_task(entry.handler(message))
task.set_name(entry.name)
@@ -179,7 +188,7 @@ class EventBus:
return continue_flag, message
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
message_dict = message.to_transport_dict() if message else None
new_continue, modified_dict = await prm.bridge_event(
event_type_value=event_value,
@@ -197,12 +206,7 @@ class EventBus:
@staticmethod
def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages:
"""将 IPC 返回的消息字典回写到当前 MaiMessages。"""
updated_message = message.deepcopy()
valid_fields = {field.name for field in fields(MaiMessages)}
for key, value in modified_dict.items():
if key in valid_fields:
setattr(updated_message, key, value)
return updated_message
return message.apply_transport_update(modified_dict)
class _HandlerEntry:

View File

@@ -1,8 +1,8 @@
import copy
import warnings
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
@@ -313,6 +313,86 @@ class MaiMessages:
def deepcopy(self):
return copy.deepcopy(self)
def to_transport_dict(self) -> Dict[str, Any]:
"""将消息转换为可通过 IPC 传输的纯字典。"""
return {
field_info.name: self._serialize_transport_value(getattr(self, field_info.name))
for field_info in fields(MaiMessages)
if field_info.name != "_modify_flags"
}
def apply_transport_update(self, modified_dict: Dict[str, Any]) -> "MaiMessages":
"""将 IPC 返回的消息字典回写到当前消息对象。"""
updated_message = self.deepcopy()
valid_fields = {field_info.name for field_info in fields(MaiMessages) if field_info.name != "_modify_flags"}
for key, value in modified_dict.items():
if key not in valid_fields:
continue
deserialized = self._deserialize_transport_field(key, value)
setattr(updated_message, key, deserialized)
if key == "message_segments":
updated_message._modify_flags.modify_message_segments = True
elif key == "plain_text":
updated_message._modify_flags.modify_plain_text = True
elif key == "llm_prompt":
updated_message._modify_flags.modify_llm_prompt = True
elif key == "llm_response_content":
updated_message._modify_flags.modify_llm_response_content = True
elif key == "llm_response_reasoning":
updated_message._modify_flags.modify_llm_response_reasoning = True
return updated_message
@staticmethod
def _serialize_transport_value(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, Enum):
return value.value
if isinstance(value, list):
return [MaiMessages._serialize_transport_value(item) for item in value]
if isinstance(value, tuple):
return [MaiMessages._serialize_transport_value(item) for item in value]
if isinstance(value, dict):
return {key: MaiMessages._serialize_transport_value(item) for key, item in value.items()}
if hasattr(value, "__dict__"):
return {
key: MaiMessages._serialize_transport_value(item)
for key, item in vars(value).items()
if not key.startswith("_")
}
return value
@staticmethod
def _deserialize_transport_field(field_name: str, value: Any) -> Any:
if field_name == "message_segments" and isinstance(value, list):
deserialized_segments: List[Seg] = []
for segment in value:
if isinstance(segment, Seg):
deserialized_segments.append(segment)
elif isinstance(segment, dict) and "type" in segment:
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data")))
return deserialized_segments
if field_name == "llm_response_tool_call" and isinstance(value, list):
deserialized_tool_calls: List[ToolCall] = []
for tool_call in value:
if isinstance(tool_call, ToolCall):
deserialized_tool_calls.append(tool_call)
elif isinstance(tool_call, dict):
deserialized_tool_calls.append(
ToolCall(
call_id=str(tool_call.get("call_id", "")),
func_name=str(tool_call.get("func_name", "")),
args=tool_call.get("args"),
)
)
return deserialized_tool_calls
return value
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
"""
修改消息段列表