重构事件总线和插件运行时,优化消息处理逻辑,新增 IPC 传输字典转换功能,改进组件管理协议
This commit is contained in:
@@ -9,7 +9,6 @@
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from dataclasses import fields
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -101,20 +100,30 @@ class EventBus:
|
||||
|
||||
continue_flag = True
|
||||
current_message = message.deepcopy() if message else None
|
||||
intercept_handlers: List[_HandlerEntry] = []
|
||||
async_handlers: List[_HandlerEntry] = []
|
||||
|
||||
for entry in handlers:
|
||||
if entry.intercept:
|
||||
try:
|
||||
should_continue, modified = await entry.handler(current_message)
|
||||
if modified is not None:
|
||||
current_message = modified
|
||||
if not should_continue:
|
||||
continue_flag = False
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
|
||||
intercept_handlers.append(entry)
|
||||
else:
|
||||
self._fire_and_forget(entry, event_type, current_message)
|
||||
async_handlers.append(entry)
|
||||
|
||||
for entry in intercept_handlers:
|
||||
try:
|
||||
should_continue, modified = await entry.handler(current_message)
|
||||
if modified is not None:
|
||||
current_message = modified
|
||||
if not should_continue:
|
||||
continue_flag = False
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True)
|
||||
|
||||
if continue_flag:
|
||||
for entry in async_handlers:
|
||||
async_message = current_message.deepcopy() if current_message else None
|
||||
self._fire_and_forget(entry, event_type, async_message)
|
||||
|
||||
# 桥接到 IPC 插件运行时
|
||||
continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message)
|
||||
@@ -138,7 +147,7 @@ class EventBus:
|
||||
event_type: EventType | str,
|
||||
message: Optional[MaiMessages],
|
||||
) -> None:
|
||||
"""创建异步任务执行非拦截型 handler"""
|
||||
"""创建异步任务执行非拦截型 handler。"""
|
||||
try:
|
||||
task = asyncio.create_task(entry.handler(message))
|
||||
task.set_name(entry.name)
|
||||
@@ -179,7 +188,7 @@ class EventBus:
|
||||
return continue_flag, message
|
||||
|
||||
event_value = event_type.value if isinstance(event_type, EventType) else str(event_type)
|
||||
message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None
|
||||
message_dict = message.to_transport_dict() if message else None
|
||||
|
||||
new_continue, modified_dict = await prm.bridge_event(
|
||||
event_type_value=event_value,
|
||||
@@ -197,12 +206,7 @@ class EventBus:
|
||||
@staticmethod
|
||||
def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages:
|
||||
"""将 IPC 返回的消息字典回写到当前 MaiMessages。"""
|
||||
updated_message = message.deepcopy()
|
||||
valid_fields = {field.name for field in fields(MaiMessages)}
|
||||
for key, value in modified_dict.items():
|
||||
if key in valid_fields:
|
||||
setattr(updated_message, key, value)
|
||||
return updated_message
|
||||
return message.apply_transport_update(modified_dict)
|
||||
|
||||
|
||||
class _HandlerEntry:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import copy
|
||||
import warnings
|
||||
from dataclasses import dataclass, field, fields
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from maim_message import Seg
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
@@ -313,6 +313,86 @@ class MaiMessages:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def to_transport_dict(self) -> Dict[str, Any]:
|
||||
"""将消息转换为可通过 IPC 传输的纯字典。"""
|
||||
return {
|
||||
field_info.name: self._serialize_transport_value(getattr(self, field_info.name))
|
||||
for field_info in fields(MaiMessages)
|
||||
if field_info.name != "_modify_flags"
|
||||
}
|
||||
|
||||
def apply_transport_update(self, modified_dict: Dict[str, Any]) -> "MaiMessages":
|
||||
"""将 IPC 返回的消息字典回写到当前消息对象。"""
|
||||
updated_message = self.deepcopy()
|
||||
valid_fields = {field_info.name for field_info in fields(MaiMessages) if field_info.name != "_modify_flags"}
|
||||
|
||||
for key, value in modified_dict.items():
|
||||
if key not in valid_fields:
|
||||
continue
|
||||
deserialized = self._deserialize_transport_field(key, value)
|
||||
setattr(updated_message, key, deserialized)
|
||||
|
||||
if key == "message_segments":
|
||||
updated_message._modify_flags.modify_message_segments = True
|
||||
elif key == "plain_text":
|
||||
updated_message._modify_flags.modify_plain_text = True
|
||||
elif key == "llm_prompt":
|
||||
updated_message._modify_flags.modify_llm_prompt = True
|
||||
elif key == "llm_response_content":
|
||||
updated_message._modify_flags.modify_llm_response_content = True
|
||||
elif key == "llm_response_reasoning":
|
||||
updated_message._modify_flags.modify_llm_response_reasoning = True
|
||||
|
||||
return updated_message
|
||||
|
||||
@staticmethod
|
||||
def _serialize_transport_value(value: Any) -> Any:
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
if isinstance(value, list):
|
||||
return [MaiMessages._serialize_transport_value(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [MaiMessages._serialize_transport_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: MaiMessages._serialize_transport_value(item) for key, item in value.items()}
|
||||
if hasattr(value, "__dict__"):
|
||||
return {
|
||||
key: MaiMessages._serialize_transport_value(item)
|
||||
for key, item in vars(value).items()
|
||||
if not key.startswith("_")
|
||||
}
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_transport_field(field_name: str, value: Any) -> Any:
|
||||
if field_name == "message_segments" and isinstance(value, list):
|
||||
deserialized_segments: List[Seg] = []
|
||||
for segment in value:
|
||||
if isinstance(segment, Seg):
|
||||
deserialized_segments.append(segment)
|
||||
elif isinstance(segment, dict) and "type" in segment:
|
||||
deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data")))
|
||||
return deserialized_segments
|
||||
|
||||
if field_name == "llm_response_tool_call" and isinstance(value, list):
|
||||
deserialized_tool_calls: List[ToolCall] = []
|
||||
for tool_call in value:
|
||||
if isinstance(tool_call, ToolCall):
|
||||
deserialized_tool_calls.append(tool_call)
|
||||
elif isinstance(tool_call, dict):
|
||||
deserialized_tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=str(tool_call.get("call_id", "")),
|
||||
func_name=str(tool_call.get("func_name", "")),
|
||||
args=tool_call.get("args"),
|
||||
)
|
||||
)
|
||||
return deserialized_tool_calls
|
||||
|
||||
return value
|
||||
|
||||
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
|
||||
"""
|
||||
修改消息段列表
|
||||
|
||||
@@ -1,12 +1,33 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_runtime.host.component_registry import RegisteredComponent
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
|
||||
class _RuntimeComponentManagerProtocol(Protocol):
|
||||
@property
|
||||
def supervisors(self) -> List["PluginSupervisor"]: ...
|
||||
|
||||
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
|
||||
|
||||
def _resolve_component_toggle_target(
|
||||
self, name: str, component_type: str
|
||||
) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
|
||||
|
||||
def _find_duplicate_plugin_ids(self, plugin_dirs: List[str]) -> Dict[str, List[str]]: ...
|
||||
|
||||
def _iter_plugin_dirs(self) -> Iterable[str]: ...
|
||||
|
||||
|
||||
class RuntimeComponentCapabilityMixin:
|
||||
async def _cap_component_get_all_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_get_all_plugins(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
result: Dict[str, Any] = {}
|
||||
for sv in self.supervisors:
|
||||
for pid, reg in sv._registered_plugins.items():
|
||||
@@ -34,7 +55,9 @@ class RuntimeComponentCapabilityMixin:
|
||||
}
|
||||
return {"success": True, "plugins": result}
|
||||
|
||||
async def _cap_component_get_plugin_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_get_plugin_info(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
plugin_name: str = args.get("plugin_name", plugin_id)
|
||||
try:
|
||||
sv = self._get_supervisor_for_plugin(plugin_name)
|
||||
@@ -54,20 +77,26 @@ class RuntimeComponentCapabilityMixin:
|
||||
}
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
|
||||
async def _cap_component_list_loaded_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_list_loaded_plugins(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
plugins: List[str] = []
|
||||
for sv in self.supervisors:
|
||||
plugins.extend(sv._registered_plugins.keys())
|
||||
return {"success": True, "plugins": plugins}
|
||||
|
||||
async def _cap_component_list_registered_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_list_registered_plugins(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
plugins: List[str] = []
|
||||
for sv in self.supervisors:
|
||||
plugins.extend(sv._registered_plugins.keys())
|
||||
return {"success": True, "plugins": plugins}
|
||||
|
||||
def _resolve_component_toggle_target(self, name: str, component_type: str) -> tuple[Optional[Any], Optional[str]]:
|
||||
short_name_matches: List[Any] = []
|
||||
def _resolve_component_toggle_target(
|
||||
self: _RuntimeComponentManagerProtocol, name: str, component_type: str
|
||||
) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
|
||||
short_name_matches: List["RegisteredComponent"] = []
|
||||
for sv in self.supervisors:
|
||||
comp = sv.component_registry.get_component(name)
|
||||
if comp is not None and comp.component_type == component_type:
|
||||
@@ -85,7 +114,9 @@ class RuntimeComponentCapabilityMixin:
|
||||
return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
|
||||
return None, f"未找到组件: {name} ({component_type})"
|
||||
|
||||
async def _cap_component_enable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_enable(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
name: str = args.get("name", "")
|
||||
component_type: str = args.get("component_type", "")
|
||||
scope: str = args.get("scope", "global")
|
||||
@@ -102,7 +133,9 @@ class RuntimeComponentCapabilityMixin:
|
||||
comp.enabled = True
|
||||
return {"success": True}
|
||||
|
||||
async def _cap_component_disable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_disable(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
name: str = args.get("name", "")
|
||||
component_type: str = args.get("component_type", "")
|
||||
scope: str = args.get("scope", "global")
|
||||
@@ -119,7 +152,9 @@ class RuntimeComponentCapabilityMixin:
|
||||
comp.enabled = False
|
||||
return {"success": True}
|
||||
|
||||
async def _cap_component_load_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_load_plugin(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
plugin_name: str = args.get("plugin_name", "")
|
||||
if not plugin_name:
|
||||
return {"success": False, "error": "缺少必要参数 plugin_name"}
|
||||
@@ -162,10 +197,14 @@ class RuntimeComponentCapabilityMixin:
|
||||
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
|
||||
async def _cap_component_unload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_unload_plugin(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"}
|
||||
|
||||
async def _cap_component_reload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
async def _cap_component_reload_plugin(
|
||||
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
|
||||
) -> Any:
|
||||
plugin_name: str = args.get("plugin_name", "")
|
||||
if not plugin_name:
|
||||
return {"success": False, "error": "缺少必要参数 plugin_name"}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
|
||||
@@ -196,9 +197,12 @@ class RuntimeCoreCapabilityMixin:
|
||||
serialized_tool_calls = None
|
||||
if tool_calls:
|
||||
serialized_tool_calls = [
|
||||
{"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
|
||||
for tc in tool_calls
|
||||
if hasattr(tc, "function")
|
||||
{
|
||||
"id": tool_call.call_id,
|
||||
"function": {"name": tool_call.func_name, "arguments": tool_call.args or {}},
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
if isinstance(tool_call, ToolCall)
|
||||
]
|
||||
return {
|
||||
"success": success,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
|
||||
@@ -260,8 +260,8 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_messages(messages: list) -> List[Dict[str, Any]]:
|
||||
result: List[Dict[str, Any]] = []
|
||||
def _serialize_messages(messages: list) -> List[Any]:
|
||||
result: List[Any] = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, "model_dump"):
|
||||
result.append(msg.model_dump())
|
||||
|
||||
@@ -90,23 +90,38 @@ class EventDispatcher:
|
||||
|
||||
should_continue = True
|
||||
modified_message: Optional[Dict[str, Any]] = None
|
||||
intercept_handlers: List[RegisteredComponent] = []
|
||||
async_handlers: List[RegisteredComponent] = []
|
||||
|
||||
for handler in handlers:
|
||||
intercept = handler.metadata.get("intercept_message", False)
|
||||
if handler.metadata.get("intercept_message", False):
|
||||
intercept_handlers.append(handler)
|
||||
else:
|
||||
async_handlers.append(handler)
|
||||
|
||||
for handler in intercept_handlers:
|
||||
args = {
|
||||
"event_type": event_type,
|
||||
"message": modified_message or message,
|
||||
**(extra_args or {}),
|
||||
}
|
||||
|
||||
if intercept:
|
||||
# 阻塞执行
|
||||
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
|
||||
if result and not result.continue_processing:
|
||||
should_continue = False
|
||||
if result and result.modified_message:
|
||||
modified_message = result.modified_message
|
||||
else:
|
||||
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
|
||||
if result and not result.continue_processing:
|
||||
should_continue = False
|
||||
break
|
||||
if result and result.modified_message:
|
||||
modified_message = result.modified_message
|
||||
|
||||
if should_continue:
|
||||
final_message = modified_message or message
|
||||
for handler in async_handlers:
|
||||
async_message = final_message.copy() if isinstance(final_message, dict) else final_message
|
||||
args = {
|
||||
"event_type": event_type,
|
||||
"message": async_message,
|
||||
**(extra_args or {}),
|
||||
}
|
||||
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
|
||||
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
|
||||
self._background_tasks.add(task)
|
||||
|
||||
@@ -7,12 +7,12 @@
|
||||
4. 提供统一的能力实现注册接口,使插件可以调用主程序功能
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager, global_config
|
||||
@@ -219,12 +219,11 @@ class PluginRuntimeManager(
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
tasks = [
|
||||
if tasks := [
|
||||
self.notify_plugin_config_updated(plugin_id)
|
||||
for sv in self.supervisors
|
||||
for plugin_id in list(sv._registered_plugins.keys())
|
||||
]
|
||||
if tasks:
|
||||
]:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# ─── 事件桥接 ──────────────────────────────────────────────
|
||||
@@ -393,7 +392,7 @@ class PluginRuntimeManager(
|
||||
for supervisor in self.supervisors:
|
||||
yield from getattr(supervisor, "_plugin_dirs", [])
|
||||
|
||||
async def _handle_plugin_file_changes(self, changes: List[FileChange]) -> None:
|
||||
async def _handle_plugin_file_changes(self, changes: Sequence[FileChange]) -> None:
|
||||
if not self._started or not changes:
|
||||
return
|
||||
|
||||
@@ -405,7 +404,7 @@ class PluginRuntimeManager(
|
||||
return
|
||||
|
||||
reload_supervisors: List[Any] = []
|
||||
config_updates: Dict[str, set[str]] = {}
|
||||
config_updates: Dict[int, set[str]] = {}
|
||||
changed_paths = [change.path.resolve() for change in changes]
|
||||
|
||||
for supervisor in self.supervisors:
|
||||
|
||||
Reference in New Issue
Block a user