重构事件总线和插件运行时,优化消息处理逻辑,新增 IPC 传输字典转换功能,改进组件管理协议

This commit is contained in:
DrSmoothl
2026-03-14 01:39:59 +08:00
parent 84212e8e95
commit 2c330e3902
7 changed files with 197 additions and 56 deletions

View File

@@ -9,7 +9,6 @@
import asyncio
import contextlib
from dataclasses import fields
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
@@ -101,20 +100,30 @@ class EventBus:
continue_flag = True
current_message = message.deepcopy() if message else None
intercept_handlers: List[_HandlerEntry] = []
async_handlers: List[_HandlerEntry] = []
for entry in handlers:
if entry.intercept:
try:
should_continue, modified = await entry.handler(current_message)
if modified is not None:
current_message = modified
if not should_continue:
continue_flag = False
break
except Exception as e:
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
intercept_handlers.append(entry)
else:
self._fire_and_forget(entry, event_type, current_message)
async_handlers.append(entry)
for entry in intercept_handlers:
try:
should_continue, modified = await entry.handler(current_message)
if modified is not None:
current_message = modified
if not should_continue:
continue_flag = False
break
except Exception as e:
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
if continue_flag:
for entry in async_handlers:
async_message = current_message.deepcopy() if current_message else None
self._fire_and_forget(entry, event_type, async_message)
# 桥接到 IPC 插件运行时
continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message)
@@ -138,7 +147,7 @@ class EventBus:
event_type: EventType | str,
message: Optional[MaiMessages],
) -> None:
"""创建异步任务执行非拦截型 handler"""
"""创建异步任务执行非拦截型 handler"""
try:
task = asyncio.create_task(entry.handler(message))
task.set_name(entry.name)
@@ -179,7 +188,7 @@ class EventBus:
return continue_flag, message
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
message_dict = message.to_transport_dict() if message else None
new_continue, modified_dict = await prm.bridge_event(
event_type_value=event_value,
@@ -197,12 +206,7 @@ class EventBus:
@staticmethod
def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages:
"""将 IPC 返回的消息字典回写到当前 MaiMessages。"""
updated_message = message.deepcopy()
valid_fields = {field.name for field in fields(MaiMessages)}
for key, value in modified_dict.items():
if key in valid_fields:
setattr(updated_message, key, value)
return updated_message
return message.apply_transport_update(modified_dict)
class _HandlerEntry:

View File

@@ -1,8 +1,8 @@
import copy
import warnings
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
@@ -313,6 +313,86 @@ class MaiMessages:
def deepcopy(self):
return copy.deepcopy(self)
def to_transport_dict(self) -> Dict[str, Any]:
"""将消息转换为可通过 IPC 传输的纯字典。"""
return {
field_info.name: self._serialize_transport_value(getattr(self, field_info.name))
for field_info in fields(MaiMessages)
if field_info.name != "_modify_flags"
}
def apply_transport_update(self, modified_dict: Dict[str, Any]) -> "MaiMessages":
"""将 IPC 返回的消息字典回写到当前消息对象。"""
updated_message = self.deepcopy()
valid_fields = {field_info.name for field_info in fields(MaiMessages) if field_info.name != "_modify_flags"}
for key, value in modified_dict.items():
if key not in valid_fields:
continue
deserialized = self._deserialize_transport_field(key, value)
setattr(updated_message, key, deserialized)
if key == "message_segments":
updated_message._modify_flags.modify_message_segments = True
elif key == "plain_text":
updated_message._modify_flags.modify_plain_text = True
elif key == "llm_prompt":
updated_message._modify_flags.modify_llm_prompt = True
elif key == "llm_response_content":
updated_message._modify_flags.modify_llm_response_content = True
elif key == "llm_response_reasoning":
updated_message._modify_flags.modify_llm_response_reasoning = True
return updated_message
@staticmethod
def _serialize_transport_value(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, Enum):
return value.value
if isinstance(value, list):
return [MaiMessages._serialize_transport_value(item) for item in value]
if isinstance(value, tuple):
return [MaiMessages._serialize_transport_value(item) for item in value]
if isinstance(value, dict):
return {key: MaiMessages._serialize_transport_value(item) for key, item in value.items()}
if hasattr(value, "__dict__"):
return {
key: MaiMessages._serialize_transport_value(item)
for key, item in vars(value).items()
if not key.startswith("_")
}
return value
@staticmethod
def _deserialize_transport_field(field_name: str, value: Any) -> Any:
if field_name == "message_segments" and isinstance(value, list):
deserialized_segments: List[Seg] = []
for segment in value:
if isinstance(segment, Seg):
deserialized_segments.append(segment)
elif isinstance(segment, dict) and "type" in segment:
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data")))
return deserialized_segments
if field_name == "llm_response_tool_call" and isinstance(value, list):
deserialized_tool_calls: List[ToolCall] = []
for tool_call in value:
if isinstance(tool_call, ToolCall):
deserialized_tool_calls.append(tool_call)
elif isinstance(tool_call, dict):
deserialized_tool_calls.append(
ToolCall(
call_id=str(tool_call.get("call_id", "")),
func_name=str(tool_call.get("func_name", "")),
args=tool_call.get("args"),
)
)
return deserialized_tool_calls
return value
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
"""
修改消息段列表

View File

@@ -1,12 +1,33 @@
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration")
if TYPE_CHECKING:
from src.plugin_runtime.host.component_registry import RegisteredComponent
from src.plugin_runtime.host.supervisor import PluginSupervisor
class _RuntimeComponentManagerProtocol(Protocol):
@property
def supervisors(self) -> List["PluginSupervisor"]: ...
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
def _resolve_component_toggle_target(
self, name: str, component_type: str
) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
def _find_duplicate_plugin_ids(self, plugin_dirs: List[str]) -> Dict[str, List[str]]: ...
def _iter_plugin_dirs(self) -> Iterable[str]: ...
class RuntimeComponentCapabilityMixin:
async def _cap_component_get_all_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_get_all_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
result: Dict[str, Any] = {}
for sv in self.supervisors:
for pid, reg in sv._registered_plugins.items():
@@ -34,7 +55,9 @@ class RuntimeComponentCapabilityMixin:
}
return {"success": True, "plugins": result}
async def _cap_component_get_plugin_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_get_plugin_info(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugin_name: str = args.get("plugin_name", plugin_id)
try:
sv = self._get_supervisor_for_plugin(plugin_name)
@@ -54,20 +77,26 @@ class RuntimeComponentCapabilityMixin:
}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
async def _cap_component_list_loaded_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_list_loaded_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugins: List[str] = []
for sv in self.supervisors:
plugins.extend(sv._registered_plugins.keys())
return {"success": True, "plugins": plugins}
async def _cap_component_list_registered_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_list_registered_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugins: List[str] = []
for sv in self.supervisors:
plugins.extend(sv._registered_plugins.keys())
return {"success": True, "plugins": plugins}
def _resolve_component_toggle_target(self, name: str, component_type: str) -> tuple[Optional[Any], Optional[str]]:
short_name_matches: List[Any] = []
def _resolve_component_toggle_target(
self: _RuntimeComponentManagerProtocol, name: str, component_type: str
) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
short_name_matches: List["RegisteredComponent"] = []
for sv in self.supervisors:
comp = sv.component_registry.get_component(name)
if comp is not None and comp.component_type == component_type:
@@ -85,7 +114,9 @@ class RuntimeComponentCapabilityMixin:
return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
return None, f"未找到组件: {name} ({component_type})"
async def _cap_component_enable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_enable(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
scope: str = args.get("scope", "global")
@@ -102,7 +133,9 @@ class RuntimeComponentCapabilityMixin:
comp.enabled = True
return {"success": True}
async def _cap_component_disable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_disable(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
scope: str = args.get("scope", "global")
@@ -119,7 +152,9 @@ class RuntimeComponentCapabilityMixin:
comp.enabled = False
return {"success": True}
async def _cap_component_load_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_load_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugin_name: str = args.get("plugin_name", "")
if not plugin_name:
return {"success": False, "error": "缺少必要参数 plugin_name"}
@@ -162,10 +197,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"未找到插件: {plugin_name}"}
async def _cap_component_unload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_unload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"}
async def _cap_component_reload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
async def _cap_component_reload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
plugin_name: str = args.get("plugin_name", "")
if not plugin_name:
return {"success": False, "error": "缺少必要参数 plugin_name"}

View File

@@ -1,7 +1,8 @@
from typing import Any, Dict
from src.config.config import global_config
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.payload_content.tool_option import ToolCall
logger = get_logger("plugin_runtime.integration")
@@ -196,9 +197,12 @@ class RuntimeCoreCapabilityMixin:
serialized_tool_calls = None
if tool_calls:
serialized_tool_calls = [
{"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
for tc in tool_calls
if hasattr(tc, "function")
{
"id": tool_call.call_id,
"function": {"name": tool_call.func_name, "arguments": tool_call.args or {}},
}
for tool_call in tool_calls
if isinstance(tool_call, ToolCall)
]
return {
"success": success,

View File

@@ -1,7 +1,7 @@
import random
from pathlib import Path
from typing import Any, Dict, List, Optional
import random
import time
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
@@ -260,8 +260,8 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": str(e)}
@staticmethod
def _serialize_messages(messages: list) -> List[Dict[str, Any]]:
result: List[Dict[str, Any]] = []
def _serialize_messages(messages: list) -> List[Any]:
result: List[Any] = []
for msg in messages:
if hasattr(msg, "model_dump"):
result.append(msg.model_dump())

View File

@@ -90,23 +90,38 @@ class EventDispatcher:
should_continue = True
modified_message: Optional[Dict[str, Any]] = None
intercept_handlers: List[RegisteredComponent] = []
async_handlers: List[RegisteredComponent] = []
for handler in handlers:
intercept = handler.metadata.get("intercept_message", False)
if handler.metadata.get("intercept_message", False):
intercept_handlers.append(handler)
else:
async_handlers.append(handler)
for handler in intercept_handlers:
args = {
"event_type": event_type,
"message": modified_message or message,
**(extra_args or {}),
}
if intercept:
# 阻塞执行
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
if result and not result.continue_processing:
should_continue = False
if result and result.modified_message:
modified_message = result.modified_message
else:
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
if result and not result.continue_processing:
should_continue = False
break
if result and result.modified_message:
modified_message = result.modified_message
if should_continue:
final_message = modified_message or message
for handler in async_handlers:
async_message = final_message.copy() if isinstance(final_message, dict) else final_message
args = {
"event_type": event_type,
"message": async_message,
**(extra_args or {}),
}
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
self._background_tasks.add(task)

View File

@@ -7,12 +7,12 @@
4. 提供统一的能力实现注册接口,使插件可以调用主程序功能
"""
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple
import asyncio
import json
import os
from pathlib import Path
from src.common.logger import get_logger
from src.config.config import config_manager, global_config
@@ -219,12 +219,11 @@ class PluginRuntimeManager(
if not self._started:
return
tasks = [
if tasks := [
self.notify_plugin_config_updated(plugin_id)
for sv in self.supervisors
for plugin_id in list(sv._registered_plugins.keys())
]
if tasks:
]:
await asyncio.gather(*tasks, return_exceptions=True)
# ─── 事件桥接 ──────────────────────────────────────────────
@@ -393,7 +392,7 @@ class PluginRuntimeManager(
for supervisor in self.supervisors:
yield from getattr(supervisor, "_plugin_dirs", [])
async def _handle_plugin_file_changes(self, changes: List[FileChange]) -> None:
async def _handle_plugin_file_changes(self, changes: Sequence[FileChange]) -> None:
if not self._started or not changes:
return
@@ -405,7 +404,7 @@ class PluginRuntimeManager(
return
reload_supervisors: List[Any] = []
config_updates: Dict[str, set[str]] = {}
config_updates: Dict[int, set[str]] = {}
changed_paths = [change.path.resolve() for change in changes]
for supervisor in self.supervisors: