From 2c330e3902454e9f0818cffed6c8bb7f6a939bdd Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 14 Mar 2026 01:39:59 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=BA=8B=E4=BB=B6=E6=80=BB?= =?UTF-8?q?=E7=BA=BF=E5=92=8C=E6=8F=92=E4=BB=B6=E8=BF=90=E8=A1=8C=E6=97=B6?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E6=96=B0=E5=A2=9E=20IPC=20=E4=BC=A0?= =?UTF-8?q?=E8=BE=93=E5=AD=97=E5=85=B8=E8=BD=AC=E6=8D=A2=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E6=94=B9=E8=BF=9B=E7=BB=84=E4=BB=B6=E7=AE=A1=E7=90=86?= =?UTF-8?q?=E5=8D=8F=E8=AE=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/event_bus.py | 42 +++++----- src/core/types.py | 84 ++++++++++++++++++- src/plugin_runtime/capabilities/components.py | 63 +++++++++++--- src/plugin_runtime/capabilities/core.py | 12 ++- src/plugin_runtime/capabilities/data.py | 6 +- src/plugin_runtime/host/event_dispatcher.py | 33 ++++++-- src/plugin_runtime/integration.py | 13 ++- 7 files changed, 197 insertions(+), 56 deletions(-) diff --git a/src/core/event_bus.py b/src/core/event_bus.py index 7d93d04b..a803ccec 100644 --- a/src/core/event_bus.py +++ b/src/core/event_bus.py @@ -9,7 +9,6 @@ import asyncio import contextlib -from dataclasses import fields from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from src.common.logger import get_logger @@ -101,20 +100,30 @@ class EventBus: continue_flag = True current_message = message.deepcopy() if message else None + intercept_handlers: List[_HandlerEntry] = [] + async_handlers: List[_HandlerEntry] = [] for entry in handlers: if entry.intercept: - try: - should_continue, modified = await entry.handler(current_message) - if modified is not None: - current_message = modified - if not should_continue: - continue_flag = False - break - except Exception as e: - logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True) + intercept_handlers.append(entry) else: - self._fire_and_forget(entry, event_type, current_message) + async_handlers.append(entry) + + for entry in intercept_handlers: + try: + should_continue, modified = await entry.handler(current_message) + if modified is not None: + current_message = modified + if not should_continue: + continue_flag = False + break + except Exception as e: + logger.error(f"拦截型 handler {entry.name} 执行异常: {e}", exc_info=True) + + if continue_flag: + for entry in async_handlers: + async_message = current_message.deepcopy() if current_message else None + self._fire_and_forget(entry, event_type, async_message) # 桥接到 IPC 插件运行时 continue_flag, current_message = await self._bridge_to_ipc_runtime(event_type, continue_flag, current_message) @@ -138,7 +147,7 @@ class EventBus: event_type: EventType | str, message: Optional[MaiMessages], ) -> None: - """创建异步任务执行非拦截型 handler""" + """创建异步任务执行非拦截型 handler。""" try: task = asyncio.create_task(entry.handler(message)) task.set_name(entry.name) @@ -179,7 +188,7 @@ class EventBus: return continue_flag, message event_value = event_type.value if isinstance(event_type, EventType) else str(event_type) - message_dict = message.to_dict() if message and hasattr(message, "to_dict") else None + message_dict = message.to_transport_dict() if message else None new_continue, modified_dict = await prm.bridge_event( event_type_value=event_value, @@ -197,12 +206,7 @@ class EventBus: @staticmethod def _apply_ipc_message_update(message: MaiMessages, modified_dict: Dict[str, Any]) -> MaiMessages: """将 IPC 返回的消息字典回写到当前 MaiMessages。""" - updated_message = message.deepcopy() - valid_fields = {field.name for field in fields(MaiMessages)} - for key, value in modified_dict.items(): - if key in valid_fields: - setattr(updated_message, key, value) - return updated_message + return message.apply_transport_update(modified_dict) class _HandlerEntry: diff --git a/src/core/types.py b/src/core/types.py index bfe0834a..535352f3 100644 --- a/src/core/types.py +++ b/src/core/types.py @@ -1,8 +1,8 @@ import copy import warnings +from dataclasses import dataclass, field, fields from enum import Enum -from typing import Dict, Any, List, Optional, Tuple -from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple from maim_message import Seg from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType @@ -313,6 +313,86 @@ class MaiMessages: def deepcopy(self): return copy.deepcopy(self) + def to_transport_dict(self) -> Dict[str, Any]: + """将消息转换为可通过 IPC 传输的纯字典。""" + return { + field_info.name: self._serialize_transport_value(getattr(self, field_info.name)) + for field_info in fields(MaiMessages) + if field_info.name != "_modify_flags" + } + + def apply_transport_update(self, modified_dict: Dict[str, Any]) -> "MaiMessages": + """将 IPC 返回的消息字典回写到当前消息对象。""" + updated_message = self.deepcopy() + valid_fields = {field_info.name for field_info in fields(MaiMessages) if field_info.name != "_modify_flags"} + + for key, value in modified_dict.items(): + if key not in valid_fields: + continue + deserialized = self._deserialize_transport_field(key, value) + setattr(updated_message, key, deserialized) + + if key == "message_segments": + updated_message._modify_flags.modify_message_segments = True + elif key == "plain_text": + updated_message._modify_flags.modify_plain_text = True + elif key == "llm_prompt": + updated_message._modify_flags.modify_llm_prompt = True + elif key == "llm_response_content": + updated_message._modify_flags.modify_llm_response_content = True + elif key == "llm_response_reasoning": + updated_message._modify_flags.modify_llm_response_reasoning = True + + return updated_message + + @staticmethod + def _serialize_transport_value(value: Any) -> Any: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, Enum): + return value.value + if isinstance(value, list): + return [MaiMessages._serialize_transport_value(item) for item in value] + if isinstance(value, tuple): + return [MaiMessages._serialize_transport_value(item) for item in value] + if isinstance(value, dict): + return {key: MaiMessages._serialize_transport_value(item) for key, item in value.items()} + if hasattr(value, "__dict__"): + return { + key: MaiMessages._serialize_transport_value(item) + for key, item in vars(value).items() + if not key.startswith("_") + } + return value + + @staticmethod + def _deserialize_transport_field(field_name: str, value: Any) -> Any: + if field_name == "message_segments" and isinstance(value, list): + deserialized_segments: List[Seg] = [] + for segment in value: + if isinstance(segment, Seg): + deserialized_segments.append(segment) + elif isinstance(segment, dict) and "type" in segment: + deserialized_segments.append(Seg(type=segment.get("type", "text"), data=segment.get("data"))) + return deserialized_segments + + if field_name == "llm_response_tool_call" and isinstance(value, list): + deserialized_tool_calls: List[ToolCall] = [] + for tool_call in value: + if isinstance(tool_call, ToolCall): + deserialized_tool_calls.append(tool_call) + elif isinstance(tool_call, dict): + deserialized_tool_calls.append( + ToolCall( + call_id=str(tool_call.get("call_id", "")), + func_name=str(tool_call.get("func_name", "")), + args=tool_call.get("args"), + ) + ) + return deserialized_tool_calls + + return value + def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False): """ 修改消息段列表 diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index c5d3d20f..61bb0e39 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -1,12 +1,33 @@ -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol from src.common.logger import get_logger logger = get_logger("plugin_runtime.integration") +if TYPE_CHECKING: + from src.plugin_runtime.host.component_registry import RegisteredComponent + from src.plugin_runtime.host.supervisor import PluginSupervisor + + +class _RuntimeComponentManagerProtocol(Protocol): + @property + def supervisors(self) -> List["PluginSupervisor"]: ... + + def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ... + + def _resolve_component_toggle_target( + self, name: str, component_type: str + ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ... + + def _find_duplicate_plugin_ids(self, plugin_dirs: List[str]) -> Dict[str, List[str]]: ... + + def _iter_plugin_dirs(self) -> Iterable[str]: ... + class RuntimeComponentCapabilityMixin: - async def _cap_component_get_all_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_get_all_plugins( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: result: Dict[str, Any] = {} for sv in self.supervisors: for pid, reg in sv._registered_plugins.items(): @@ -34,7 +55,9 @@ class RuntimeComponentCapabilityMixin: } return {"success": True, "plugins": result} - async def _cap_component_get_plugin_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_get_plugin_info( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: plugin_name: str = args.get("plugin_name", plugin_id) try: sv = self._get_supervisor_for_plugin(plugin_name) @@ -54,20 +77,26 @@ class RuntimeComponentCapabilityMixin: } return {"success": False, "error": f"未找到插件: {plugin_name}"} - async def _cap_component_list_loaded_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_list_loaded_plugins( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: plugins: List[str] = [] for sv in self.supervisors: plugins.extend(sv._registered_plugins.keys()) return {"success": True, "plugins": plugins} - async def _cap_component_list_registered_plugins(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_list_registered_plugins( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: plugins: List[str] = [] for sv in self.supervisors: plugins.extend(sv._registered_plugins.keys()) return {"success": True, "plugins": plugins} - def _resolve_component_toggle_target(self, name: str, component_type: str) -> tuple[Optional[Any], Optional[str]]: - short_name_matches: List[Any] = [] + def _resolve_component_toggle_target( + self: _RuntimeComponentManagerProtocol, name: str, component_type: str + ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: + short_name_matches: List["RegisteredComponent"] = [] for sv in self.supervisors: comp = sv.component_registry.get_component(name) if comp is not None and comp.component_type == component_type: @@ -85,7 +114,9 @@ class RuntimeComponentCapabilityMixin: return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name" return None, f"未找到组件: {name} ({component_type})" - async def _cap_component_enable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_enable( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: name: str = args.get("name", "") component_type: str = args.get("component_type", "") scope: str = args.get("scope", "global") @@ -102,7 +133,9 @@ class RuntimeComponentCapabilityMixin: comp.enabled = True return {"success": True} - async def _cap_component_disable(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_disable( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: name: str = args.get("name", "") component_type: str = args.get("component_type", "") scope: str = args.get("scope", "global") @@ -119,7 +152,9 @@ class RuntimeComponentCapabilityMixin: comp.enabled = False return {"success": True} - async def _cap_component_load_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_load_plugin( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: plugin_name: str = args.get("plugin_name", "") if not plugin_name: return {"success": False, "error": "缺少必要参数 plugin_name"} @@ -162,10 +197,14 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": f"未找到插件: {plugin_name}"} - async def _cap_component_unload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_unload_plugin( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"} - async def _cap_component_reload_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any: + async def _cap_component_reload_plugin( + self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] + ) -> Any: plugin_name: str = args.get("plugin_name", "") if not plugin_name: return {"success": False, "error": "缺少必要参数 plugin_name"} diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py index 8bff6d00..def5f03d 100644 --- a/src/plugin_runtime/capabilities/core.py +++ b/src/plugin_runtime/capabilities/core.py @@ -1,7 +1,8 @@ from typing import Any, Dict -from src.config.config import global_config from src.common.logger import get_logger +from src.config.config import global_config +from src.llm_models.payload_content.tool_option import ToolCall logger = get_logger("plugin_runtime.integration") @@ -196,9 +197,12 @@ class RuntimeCoreCapabilityMixin: serialized_tool_calls = None if tool_calls: serialized_tool_calls = [ - {"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}} - for tc in tool_calls - if hasattr(tc, "function") + { + "id": tool_call.call_id, + "function": {"name": tool_call.func_name, "arguments": tool_call.args or {}}, + } + for tool_call in tool_calls + if isinstance(tool_call, ToolCall) ] return { "success": success, diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index b77a3de1..c4ae0a56 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -1,7 +1,7 @@ -import random from pathlib import Path from typing import Any, Dict, List, Optional +import random import time from src.chat.message_receive.chat_manager import BotChatSession, chat_manager @@ -260,8 +260,8 @@ class RuntimeDataCapabilityMixin: return {"success": False, "error": str(e)} @staticmethod - def _serialize_messages(messages: list) -> List[Dict[str, Any]]: - result: List[Dict[str, Any]] = [] + def _serialize_messages(messages: list) -> List[Any]: + result: List[Any] = [] for msg in messages: if hasattr(msg, "model_dump"): result.append(msg.model_dump()) diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py index 48dfcedd..720e93d7 100644 --- a/src/plugin_runtime/host/event_dispatcher.py +++ b/src/plugin_runtime/host/event_dispatcher.py @@ -90,23 +90,38 @@ class EventDispatcher: should_continue = True modified_message: Optional[Dict[str, Any]] = None + intercept_handlers: List[RegisteredComponent] = [] + async_handlers: List[RegisteredComponent] = [] for handler in handlers: - intercept = handler.metadata.get("intercept_message", False) + if handler.metadata.get("intercept_message", False): + intercept_handlers.append(handler) + else: + async_handlers.append(handler) + + for handler in intercept_handlers: args = { "event_type": event_type, "message": modified_message or message, **(extra_args or {}), } - if intercept: - # 阻塞执行 - result = await self._invoke_handler(invoke_fn, handler, args, event_type) - if result and not result.continue_processing: - should_continue = False - if result and result.modified_message: - modified_message = result.modified_message - else: + result = await self._invoke_handler(invoke_fn, handler, args, event_type) + if result and not result.continue_processing: + should_continue = False + break + if result and result.modified_message: + modified_message = result.modified_message + + if should_continue: + final_message = modified_message or message + for handler in async_handlers: + async_message = final_message.copy() if isinstance(final_message, dict) else final_message + args = { + "event_type": event_type, + "message": async_message, + **(extra_args or {}), + } # 非阻塞:保持实例级强引用,防止 task 被 GC 回收 task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type)) self._background_tasks.add(task) diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index a6d04d95..125aad8d 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -7,12 +7,12 @@ 4. 提供统一的能力实现注册接口,使插件可以调用主程序功能 """ -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple import asyncio import json import os -from pathlib import Path from src.common.logger import get_logger from src.config.config import config_manager, global_config @@ -219,12 +219,11 @@ class PluginRuntimeManager( if not self._started: return - tasks = [ + if tasks := [ self.notify_plugin_config_updated(plugin_id) for sv in self.supervisors for plugin_id in list(sv._registered_plugins.keys()) - ] - if tasks: + ]: await asyncio.gather(*tasks, return_exceptions=True) # ─── 事件桥接 ────────────────────────────────────────────── @@ -393,7 +392,7 @@ class PluginRuntimeManager( for supervisor in self.supervisors: yield from getattr(supervisor, "_plugin_dirs", []) - async def _handle_plugin_file_changes(self, changes: List[FileChange]) -> None: + async def _handle_plugin_file_changes(self, changes: Sequence[FileChange]) -> None: if not self._started or not changes: return @@ -405,7 +404,7 @@ class PluginRuntimeManager( return reload_supervisors: List[Any] = [] - config_updates: Dict[str, set[str]] = {} + config_updates: Dict[int, set[str]] = {} changed_paths = [change.path.resolve() for change in changes] for supervisor in self.supervisors: