重构事件总线和插件运行时,优化消息处理逻辑,新增 IPC 传输字典转换功能,改进组件管理协议

This commit is contained in:
DrSmoothl
2026-03-14 01:39:59 +08:00
parent 84212e8e95
commit 2c330e3902
7 changed files with 197 additions and 56 deletions

View File

@@ -9,7 +9,6 @@
import asyncio import asyncio
import contextlib import contextlib
from dataclasses import fields
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -101,20 +100,30 @@ class EventBus:
continue_flag = True continue_flag = True
current_message = message.deepcopy() if message else None current_message = message.deepcopy() if message else None
intercept_handlers: List[_HandlerEntry] = []
async_handlers: List[_HandlerEntry] = []
for entry in handlers: for entry in handlers:
if entry.intercept: if entry.intercept:
try: intercept_handlers.append(entry)
should_continue, modified = await entry.handler(current_message)
if modified is not None:
current_message = modified
if not should_continue:
continue_flag = False
break
except Exception as e:
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
else: else:
self._fire_and_forget(entry, event_type, current_message) async_handlers.append(entry)
for entry in intercept_handlers:
try:
should_continue, modified = await entry.handler(current_message)
if modified is not None:
current_message = modified
if not should_continue:
continue_flag = False
break
except Exception as e:
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
if continue_flag:
for entry in async_handlers:
async_message = current_message.deepcopy() if current_message else None
self._fire_and_forget(entry, event_type, async_message)
# 桥接到 IPC 插件运行时 # 桥接到 IPC 插件运行时
continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message) continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message)
@@ -138,7 +147,7 @@ class EventBus:
event_type: EventType | str, event_type: EventType | str,
message: Optional[MaiMessages], message: Optional[MaiMessages],
) -> None: ) -> None:
"""创建异步任务执行非拦截型 handler""" """创建异步任务执行非拦截型 handler"""
try: try:
task = asyncio.create_task(entry.handler(message)) task = asyncio.create_task(entry.handler(message))
task.set_name(entry.name) task.set_name(entry.name)
@@ -179,7 +188,7 @@ class EventBus:
return continue_flag, message return continue_flag, message
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type) event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None message_dict = message.to_transport_dict() if message else None
new_continue, modified_dict = await prm.bridge_event( new_continue, modified_dict = await prm.bridge_event(
event_type_value=event_value, event_type_value=event_value,
@@ -197,12 +206,7 @@ class EventBus:
@staticmethod @staticmethod
def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages: def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages:
"""将 IPC 返回的消息字典回写到当前 MaiMessages。""" """将 IPC 返回的消息字典回写到当前 MaiMessages。"""
updated_message = message.deepcopy() return message.apply_transport_update(modified_dict)
valid_fields = {field.name for field in fields(MaiMessages)}
for key, value in modified_dict.items():
if key in valid_fields:
setattr(updated_message, key, value)
return updated_message
class _HandlerEntry: class _HandlerEntry:

View File

@@ -1,8 +1,8 @@
import copy import copy
import warnings import warnings
from dataclasses import dataclass, field, fields
from enum import Enum from enum import Enum
from typing import Dict, Any, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from maim_message import Seg from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
@@ -313,6 +313,86 @@ class MaiMessages:
def deepcopy(self): def deepcopy(self):
return copy.deepcopy(self) return copy.deepcopy(self)
def to_transport_dict(self) -> Dict[str, Any]:
"""将消息转换为可通过 IPC 传输的纯字典。"""
return {
field_info.name: self._serialize_transport_value(getattr(self, field_info.name))
for field_info in fields(MaiMessages)
if field_info.name != "_modify_flags"
}
def apply_transport_update(self, modified_dict: Dict[str, Any]) -> "MaiMessages":
"""将 IPC 返回的消息字典回写到当前消息对象。"""
updated_message = self.deepcopy()
valid_fields = {field_info.name for field_info in fields(MaiMessages) if field_info.name != "_modify_flags"}
for key, value in modified_dict.items():
if key not in valid_fields:
continue
deserialized = self._deserialize_transport_field(key, value)
setattr(updated_message, key, deserialized)
if key == "message_segments":
updated_message._modify_flags.modify_message_segments = True
elif key == "plain_text":
updated_message._modify_flags.modify_plain_text = True
elif key == "llm_prompt":
updated_message._modify_flags.modify_llm_prompt = True
elif key == "llm_response_content":
updated_message._modify_flags.modify_llm_response_content = True
elif key == "llm_response_reasoning":
updated_message._modify_flags.modify_llm_response_reasoning = True
return updated_message
@staticmethod
def _serialize_transport_value(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, Enum):
return value.value
if isinstance(value, list):
return [MaiMessages._serialize_transport_value(item) for item in value]
if isinstance(value, tuple):
return [MaiMessages._serialize_transport_value(item) for item in value]
if isinstance(value, dict):
return {key: MaiMessages._serialize_transport_value(item) for key, item in value.items()}
if hasattr(value, "__dict__"):
return {
key: MaiMessages._serialize_transport_value(item)
for key, item in vars(value).items()
if not key.startswith("_")
}
return value
@staticmethod
def _deserialize_transport_field(field_name: str, value: Any) -> Any:
if field_name == "message_segments" and isinstance(value, list):
deserialized_segments: List[Seg] = []
for segment in value:
if isinstance(segment, Seg):
deserialized_segments.append(segment)
elif isinstance(segment, dict) and "type" in segment:
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data")))
return deserialized_segments
if field_name == "llm_response_tool_call" and isinstance(value, list):
deserialized_tool_calls: List[ToolCall] = []
for tool_call in value:
if isinstance(tool_call, ToolCall):
deserialized_tool_calls.append(tool_call)
elif isinstance(tool_call, dict):
deserialized_tool_calls.append(
ToolCall(
call_id=str(tool_call.get("call_id", "")),
func_name=str(tool_call.get("func_name", "")),
args=tool_call.get("args"),
)
)
return deserialized_tool_calls
return value
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False): def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
""" """
修改消息段列表 修改消息段列表

View File

@@ -1,12 +1,33 @@
from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration") logger = get_logger("plugin_runtime.integration")
if TYPE_CHECKING:
from src.plugin_runtime.host.component_registry import RegisteredComponent
from src.plugin_runtime.host.supervisor import PluginSupervisor
class _RuntimeComponentManagerProtocol(Protocol):
@property
def supervisors(self) -> List["PluginSupervisor"]: ...
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
def _resolve_component_toggle_target(
self, name: str, component_type: str
) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
def _find_duplicate_plugin_ids(self, plugin_dirs: List[str]) -> Dict[str, List[str]]: ...
def _iter_plugin_dirs(self) -> Iterable[str]: ...
class RuntimeComponentCapabilityMixin: class RuntimeComponentCapabilityMixin:
async def _cap_component_get_all_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_get_all_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
result: Dict[str, Any] = {} result: Dict[str, Any] = {}
for sv in self.supervisors: for sv in self.supervisors:
for pid, reg in sv._registered_plugins.items(): for pid, reg in sv._registered_plugins.items():
@@ -34,7 +55,9 @@ class RuntimeComponentCapabilityMixin:
} }
return {"success": True, "plugins": result} return {"success": True, "plugins": result}
async def _cap_component_get_plugin_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_get_plugin_info(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugin_name: str = args.get("plugin_name", plugin_id) plugin_name: str = args.get("plugin_name", plugin_id)
try: try:
sv = self._get_supervisor_for_plugin(plugin_name) sv = self._get_supervisor_for_plugin(plugin_name)
@@ -54,20 +77,26 @@ class RuntimeComponentCapabilityMixin:
} }
return {"success": False, "error": f"未找到插件: {plugin_name}"} return {"success": False, "error": f"未找到插件: {plugin_name}"}
async def _cap_component_list_loaded_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_list_loaded_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugins: List[str] = [] plugins: List[str] = []
for sv in self.supervisors: for sv in self.supervisors:
plugins.extend(sv._registered_plugins.keys()) plugins.extend(sv._registered_plugins.keys())
return {"success": True, "plugins": plugins} return {"success": True, "plugins": plugins}
async def _cap_component_list_registered_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_list_registered_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugins: List[str] = [] plugins: List[str] = []
for sv in self.supervisors: for sv in self.supervisors:
plugins.extend(sv._registered_plugins.keys()) plugins.extend(sv._registered_plugins.keys())
return {"success": True, "plugins": plugins} return {"success": True, "plugins": plugins}
def _resolve_component_toggle_target(self, name: str, component_type: str) -> tuple[Optional[Any], Optional[str]]: def _resolve_component_toggle_target(
short_name_matches: List[Any] = [] self: _RuntimeComponentManagerProtocol, name: str, component_type: str
) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
short_name_matches: List["RegisteredComponent"] = []
for sv in self.supervisors: for sv in self.supervisors:
comp = sv.component_registry.get_component(name) comp = sv.component_registry.get_component(name)
if comp is not None and comp.component_type == component_type: if comp is not None and comp.component_type == component_type:
@@ -85,7 +114,9 @@ class RuntimeComponentCapabilityMixin:
return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name" return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
return None, f"未找到组件: {name} ({component_type})" return None, f"未找到组件: {name} ({component_type})"
async def _cap_component_enable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_enable(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
name: str = args.get("name", "") name: str = args.get("name", "")
component_type: str = args.get("component_type", "") component_type: str = args.get("component_type", "")
scope: str = args.get("scope", "global") scope: str = args.get("scope", "global")
@@ -102,7 +133,9 @@ class RuntimeComponentCapabilityMixin:
comp.enabled = True comp.enabled = True
return {"success": True} return {"success": True}
async def _cap_component_disable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_disable(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
name: str = args.get("name", "") name: str = args.get("name", "")
component_type: str = args.get("component_type", "") component_type: str = args.get("component_type", "")
scope: str = args.get("scope", "global") scope: str = args.get("scope", "global")
@@ -119,7 +152,9 @@ class RuntimeComponentCapabilityMixin:
comp.enabled = False comp.enabled = False
return {"success": True} return {"success": True}
async def _cap_component_load_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_load_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugin_name: str = args.get("plugin_name", "") plugin_name: str = args.get("plugin_name", "")
if not plugin_name: if not plugin_name:
return {"success": False, "error": "缺少必要参数 plugin_name"} return {"success": False, "error": "缺少必要参数 plugin_name"}
@@ -162,10 +197,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"未找到插件: {plugin_name}"} return {"success": False, "error": f"未找到插件: {plugin_name}"}
async def _cap_component_unload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_unload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"} return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"}
async def _cap_component_reload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: async def _cap_component_reload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugin_name: str = args.get("plugin_name", "") plugin_name: str = args.get("plugin_name", "")
if not plugin_name: if not plugin_name:
return {"success": False, "error": "缺少必要参数 plugin_name"} return {"success": False, "error": "缺少必要参数 plugin_name"}

View File

@@ -1,7 +1,8 @@
from typing import Any, Dict from typing import Any, Dict
from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.payload_content.tool_option import ToolCall
logger = get_logger("plugin_runtime.integration") logger = get_logger("plugin_runtime.integration")
@@ -196,9 +197,12 @@ class RuntimeCoreCapabilityMixin:
serialized_tool_calls = None serialized_tool_calls = None
if tool_calls: if tool_calls:
serialized_tool_calls = [ serialized_tool_calls = [
{"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}} {
for tc in tool_calls "id": tool_call.call_id,
if hasattr(tc, "function") "function": {"name": tool_call.func_name, "arguments": tool_call.args or {}},
}
for tool_call in tool_calls
if isinstance(tool_call, ToolCall)
] ]
return { return {
"success": success, "success": success,

View File

@@ -1,7 +1,7 @@
import random
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import random
import time import time
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
@@ -260,8 +260,8 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
def _serialize_messages(messages: list) -> List[Dict[str, Any]]: def _serialize_messages(messages: list) -> List[Any]:
result: List[Dict[str, Any]] = [] result: List[Any] = []
for msg in messages: for msg in messages:
if hasattr(msg, "model_dump"): if hasattr(msg, "model_dump"):
result.append(msg.model_dump()) result.append(msg.model_dump())

View File

@@ -90,23 +90,38 @@ class EventDispatcher:
should_continue = True should_continue = True
modified_message: Optional[Dict[str, Any]] = None modified_message: Optional[Dict[str, Any]] = None
intercept_handlers: List[RegisteredComponent] = []
async_handlers: List[RegisteredComponent] = []
for handler in handlers: for handler in handlers:
intercept = handler.metadata.get("intercept_message", False) if handler.metadata.get("intercept_message", False):
intercept_handlers.append(handler)
else:
async_handlers.append(handler)
for handler in intercept_handlers:
args = { args = {
"event_type": event_type, "event_type": event_type,
"message": modified_message or message, "message": modified_message or message,
**(extra_args or {}), **(extra_args or {}),
} }
if intercept: result = await self._invoke_handler(invoke_fn, handler, args, event_type)
# 阻塞执行 if result and not result.continue_processing:
result = await self._invoke_handler(invoke_fn, handler, args, event_type) should_continue = False
if result and not result.continue_processing: break
should_continue = False if result and result.modified_message:
if result and result.modified_message: modified_message = result.modified_message
modified_message = result.modified_message
else: if should_continue:
final_message = modified_message or message
for handler in async_handlers:
async_message = final_message.copy() if isinstance(final_message, dict) else final_message
args = {
"event_type": event_type,
"message": async_message,
**(extra_args or {}),
}
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收 # 非阻塞:保持实例级强引用,防止 task 被 GC 回收
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type)) task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
self._background_tasks.add(task) self._background_tasks.add(task)

View File

@@ -7,12 +7,12 @@
4. 提供统一的能力实现注册接口,使插件可以调用主程序功能 4. 提供统一的能力实现注册接口,使插件可以调用主程序功能
""" """
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple
import asyncio import asyncio
import json import json
import os import os
from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import config_manager, global_config from src.config.config import config_manager, global_config
@@ -219,12 +219,11 @@ class PluginRuntimeManager(
if not self._started: if not self._started:
return return
tasks = [ if tasks := [
self.notify_plugin_config_updated(plugin_id) self.notify_plugin_config_updated(plugin_id)
for sv in self.supervisors for sv in self.supervisors
for plugin_id in list(sv._registered_plugins.keys()) for plugin_id in list(sv._registered_plugins.keys())
] ]:
if tasks:
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
# ─── 事件桥接 ────────────────────────────────────────────── # ─── 事件桥接 ──────────────────────────────────────────────
@@ -393,7 +392,7 @@ class PluginRuntimeManager(
for supervisor in self.supervisors: for supervisor in self.supervisors:
yield from getattr(supervisor, "_plugin_dirs", []) yield from getattr(supervisor, "_plugin_dirs", [])
async def _handle_plugin_file_changes(self, changes: List[FileChange]) -> None: async def _handle_plugin_file_changes(self, changes: Sequence[FileChange]) -> None:
if not self._started or not changes: if not self._started or not changes:
return return
@@ -405,7 +404,7 @@ class PluginRuntimeManager(
return return
reload_supervisors: List[Any] = [] reload_supervisors: List[Any] = []
config_updates: Dict[str, set[str]] = {} config_updates: Dict[int, set[str]] = {}
changed_paths = [change.path.resolve() for change in changes] changed_paths = [change.path.resolve() for change in changes]
for supervisor in self.supervisors: for supervisor in self.supervisors: