重构事件总线和插件运行时,优化消息处理逻辑,新增 IPC 传输字典转换功能,改进组件管理协议
This commit is contained in:
@@ -9,7 +9,6 @@
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from dataclasses import fields
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -101,20 +100,30 @@ class EventBus:
|
||||
|
||||
continue_flag = True
|
||||
current_message = message.deepcopy() if message else None
|
||||
intercept_handlers: List[_HandlerEntry] = []
|
||||
async_handlers: List[_HandlerEntry] = []
|
||||
|
||||
for entry in handlers:
|
||||
if entry.intercept:
|
||||
try:
|
||||
should_continue, modified = await entry.handler(current_message)
|
||||
if modified is not None:
|
||||
current_message = modified
|
||||
if not should_continue:
|
||||
continue_flag = False
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
|
||||
intercept_handlers.append(entry)
|
||||
else:
|
||||
self._fire_and_forget(entry, event_type, current_message)
|
||||
async_handlers.append(entry)
|
||||
|
||||
for entry in intercept_handlers:
|
||||
try:
|
||||
should_continue, modified = await entry.handler(current_message)
|
||||
if modified is not None:
|
||||
current_message = modified
|
||||
if not should_continue:
|
||||
continue_flag = False
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
|
||||
|
||||
if continue_flag:
|
||||
for entry in async_handlers:
|
||||
async_message = current_message.deepcopy() if current_message else None
|
||||
self._fire_and_forget(entry, event_type, async_message)
|
||||
|
||||
# 桥接到 IPC 插件运行时
|
||||
continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message)
|
||||
@@ -138,7 +147,7 @@ class EventBus:
|
||||
event_type: EventType | str,
|
||||
message: Optional[MaiMessages],
|
||||
) -> None:
|
||||
"""创建异步任务执行非拦截型 handler"""
|
||||
"""创建异步任务执行非拦截型 handler。"""
|
||||
try:
|
||||
task = asyncio.create_task(entry.handler(message))
|
||||
task.set_name(entry.name)
|
||||
@@ -179,7 +188,7 @@ class EventBus:
|
||||
return continue_flag, message
|
||||
|
||||
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
|
||||
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
|
||||
message_dict = message.to_transport_dict() if message else None
|
||||
|
||||
new_continue, modified_dict = await prm.bridge_event(
|
||||
event_type_value=event_value,
|
||||
@@ -197,12 +206,7 @@ class EventBus:
|
||||
@staticmethod
|
||||
def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages:
|
||||
"""将 IPC 返回的消息字典回写到当前 MaiMessages。"""
|
||||
updated_message = message.deepcopy()
|
||||
valid_fields = {field.name for field in fields(MaiMessages)}
|
||||
for key, value in modified_dict.items():
|
||||
if key in valid_fields:
|
||||
setattr(updated_message, key, value)
|
||||
return updated_message
|
||||
return message.apply_transport_update(modified_dict)
|
||||
|
||||
|
||||
class _HandlerEntry:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import copy
|
||||
import warnings
|
||||
from dataclasses import dataclass, field, fields
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from maim_message import Seg
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
@@ -313,6 +313,86 @@ class MaiMessages:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def to_transport_dict(self) -> Dict[str, Any]:
|
||||
"""将消息转换为可通过 IPC 传输的纯字典。"""
|
||||
return {
|
||||
field_info.name: self._serialize_transport_value(getattr(self, field_info.name))
|
||||
for field_info in fields(MaiMessages)
|
||||
if field_info.name != "_modify_flags"
|
||||
}
|
||||
|
||||
def apply_transport_update(self, modified_dict: Dict[str, Any]) -> "MaiMessages":
|
||||
"""将 IPC 返回的消息字典回写到当前消息对象。"""
|
||||
updated_message = self.deepcopy()
|
||||
valid_fields = {field_info.name for field_info in fields(MaiMessages) if field_info.name != "_modify_flags"}
|
||||
|
||||
for key, value in modified_dict.items():
|
||||
if key not in valid_fields:
|
||||
continue
|
||||
deserialized = self._deserialize_transport_field(key, value)
|
||||
setattr(updated_message, key, deserialized)
|
||||
|
||||
if key == "message_segments":
|
||||
updated_message._modify_flags.modify_message_segments = True
|
||||
elif key == "plain_text":
|
||||
updated_message._modify_flags.modify_plain_text = True
|
||||
elif key == "llm_prompt":
|
||||
updated_message._modify_flags.modify_llm_prompt = True
|
||||
elif key == "llm_response_content":
|
||||
updated_message._modify_flags.modify_llm_response_content = True
|
||||
elif key == "llm_response_reasoning":
|
||||
updated_message._modify_flags.modify_llm_response_reasoning = True
|
||||
|
||||
return updated_message
|
||||
|
||||
@staticmethod
|
||||
def _serialize_transport_value(value: Any) -> Any:
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
if isinstance(value, list):
|
||||
return [MaiMessages._serialize_transport_value(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [MaiMessages._serialize_transport_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: MaiMessages._serialize_transport_value(item) for key, item in value.items()}
|
||||
if hasattr(value, "__dict__"):
|
||||
return {
|
||||
key: MaiMessages._serialize_transport_value(item)
|
||||
for key, item in vars(value).items()
|
||||
if not key.startswith("_")
|
||||
}
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_transport_field(field_name: str, value: Any) -> Any:
|
||||
if field_name == "message_segments" and isinstance(value, list):
|
||||
deserialized_segments: List[Seg] = []
|
||||
for segment in value:
|
||||
if isinstance(segment, Seg):
|
||||
deserialized_segments.append(segment)
|
||||
elif isinstance(segment, dict) and "type" in segment:
|
||||
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data")))
|
||||
return deserialized_segments
|
||||
|
||||
if field_name == "llm_response_tool_call" and isinstance(value, list):
|
||||
deserialized_tool_calls: List[ToolCall] = []
|
||||
for tool_call in value:
|
||||
if isinstance(tool_call, ToolCall):
|
||||
deserialized_tool_calls.append(tool_call)
|
||||
elif isinstance(tool_call, dict):
|
||||
deserialized_tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=str(tool_call.get("call_id", "")),
|
||||
func_name=str(tool_call.get("func_name", "")),
|
||||
args=tool_call.get("args"),
|
||||
)
|
||||
)
|
||||
return deserialized_tool_calls
|
||||
|
||||
return value
|
||||
|
||||
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
|
||||
"""
|
||||
修改消息段列表
|
||||
|
||||
Reference in New Issue
Block a user