增加了event_handler修改内容的方法

This commit is contained in:
UnCLAS-Prommer
2025-09-07 01:15:21 +08:00
parent 0811cff8bf
commit b636683fe4
10 changed files with 215 additions and 85 deletions

View File

@@ -37,19 +37,19 @@ class BaseEventHandler(ABC):
@abstractmethod
async def execute(
self, message: MaiMessages | None
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]:
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
"""执行事件处理的抽象方法,子类必须实现
Args:
message (MaiMessages | None): 事件消息对象当你注册的事件为ON_START和ON_STOP时message为None
Returns:
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果)
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果,可选的修改后消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")
@classmethod
def get_handler_info(cls) -> "EventHandlerInfo":
"""获取事件处理器的信息"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
# 从类属性读取名称,如果没有定义则使用类名自动生成S
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
if "." in name:
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")

View File

@@ -1,4 +1,5 @@
import copy
import warnings
from enum import Enum
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
@@ -7,6 +8,7 @@ from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
# 组件类型枚举
class ComponentType(Enum):
"""组件类型枚举"""
@@ -56,6 +58,7 @@ class EventType(Enum):
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
ON_MESSAGE_PRE_PROCESS = "on_message_pre_process"
ON_MESSAGE = "on_message"
ON_PLAN = "on_plan"
POST_LLM = "post_llm"
@@ -116,9 +119,9 @@ class ActionInfo(ComponentInfo):
action_require: List[str] = field(default_factory=list) # 动作需求说明
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
# 激活类型相关
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
activation_type: ActionActivationType = ActionActivationType.ALWAYS
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
activation_type: ActionActivationType = ActionActivationType.ALWAYS
random_activation_probability: float = 0.0
llm_judge_prompt: str = ""
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
@@ -154,7 +157,9 @@ class CommandInfo(ComponentInfo):
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
default_factory=list
) # 工具参数定义
tool_description: str = "" # 工具描述
def __post_init__(self):
@@ -233,6 +238,15 @@ class PluginInfo:
return [dep.get_pip_requirement() for dep in self.python_dependencies]
@dataclass
class ModifyFlag:
modify_message_segments: bool = False
modify_plain_text: bool = False
modify_llm_prompt: bool = False
modify_llm_response_content: bool = False
modify_llm_response_reasoning: bool = False
@dataclass
class MaiMessages:
"""MaiM插件消息"""
@@ -263,31 +277,129 @@ class MaiMessages:
llm_response_content: Optional[str] = None
"""LLM响应内容"""
llm_response_reasoning: Optional[str] = None
"""LLM响应推理内容"""
llm_response_model: Optional[str] = None
"""LLM响应模型名称"""
llm_response_tool_call: Optional[List[ToolCall]] = None
"""LLM使用的工具调用"""
action_usage: Optional[List[str]] = None
"""使用的Action"""
additional_data: Dict[Any, Any] = field(default_factory=dict)
"""附加数据,可以存储额外信息"""
_modify_flags: ModifyFlag = field(default_factory=ModifyFlag)
def __post_init__(self):
if self.message_segments is None:
self.message_segments = []
def deepcopy(self):
return copy.deepcopy(self)
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
"""
修改消息段列表
Warning:
在生成了plain_text的情况下调用此方法可能会导致plain_text内容与消息段不一致
Args:
new_segments (List[Seg]): 新的消息段列表
"""
if self.plain_text and not suppress_warning:
warnings.warn(
"修改消息段后plain_text可能与消息段内容不一致建议同时更新plain_text",
UserWarning,
stacklevel=2,
)
self.message_segments = new_segments
self._modify_flags.modify_message_segments = True
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False):
"""
修改LLM提示词
Warning:
在没有生成llm_prompt的情况下调用此方法可能会导致修改无效
Args:
new_prompt (str): 新的提示词内容
"""
if self.llm_prompt is None and not suppress_warning:
warnings.warn(
"当前llm_prompt为空此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_prompt = new_prompt
self._modify_flags.modify_llm_prompt = True
def modify_plain_text(self, new_text: str, suppress_warning: bool = False):
"""
修改生成的plain_text内容
Warning:
在未生成plain_text的情况下调用此方法可能会导致plain_text为空或者修改无效
Args:
new_text (str): 新的纯文本内容
"""
if not self.plain_text and not suppress_warning:
warnings.warn(
"当前plain_text为空此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.plain_text = new_text
self._modify_flags.modify_plain_text = True
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False):
"""
修改生成的llm_response_content内容
Warning:
在未生成llm_response_content的情况下调用此方法可能会导致llm_response_content为空或者修改无效
Args:
new_content (str): 新的LLM响应内容
"""
if not self.llm_response_content and not suppress_warning:
warnings.warn(
"当前llm_response_content为空此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_response_content = new_content
self._modify_flags.modify_llm_response_content = True
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False):
"""
修改生成的llm_response_reasoning内容
Warning:
在未生成llm_response_reasoning的情况下调用此方法可能会导致llm_response_reasoning为空或者修改无效
Args:
new_reasoning (str): 新的LLM响应推理内容
"""
if not self.llm_response_reasoning and not suppress_warning:
warnings.warn(
"当前llm_response_reasoning为空此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_response_reasoning = new_reasoning
self._modify_flags.modify_llm_response_reasoning = True
@dataclass
class CustomEventHandlerResult:
message: str = ""
timestamp: float = 0.0
extra_info: Optional[Dict] = None
extra_info: Optional[Dict] = None

View File

@@ -71,7 +71,7 @@ class EventsManager:
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> bool:
) -> Tuple[bool, Optional[MaiMessages]]:
"""
处理所有事件,根据事件类型分发给订阅的处理器。
"""
@@ -89,10 +89,10 @@ class EventsManager:
# 2. 获取并遍历处理器
handlers = self._events_subscribers.get(event_type, [])
if not handlers:
return True
return True, None
current_stream_id = transformed_message.stream_id if transformed_message else None
modified_message: Optional[MaiMessages] = None
for handler in handlers:
# 3. 前置检查和配置加载
if (
@@ -107,15 +107,19 @@ class EventsManager:
handler.set_plugin_config(plugin_config)
# 4. 根据类型分发任务
if handler.intercept_message or event_type == EventType.ON_STOP: # 让ON_STOP的所有事件处理器都发挥作用防止还没执行即被取消
if (
handler.intercept_message or event_type == EventType.ON_STOP
): # 让ON_STOP的所有事件处理器都发挥作用防止还没执行即被取消
# 阻塞执行,并更新 continue_flag
should_continue = await self._dispatch_intercepting_handler(handler, event_type, transformed_message)
should_continue, modified_message = await self._dispatch_intercepting_handler_task(
handler, event_type, modified_message or transformed_message
)
continue_flag = continue_flag and should_continue
else:
# 异步执行,不阻塞
self._dispatch_handler_task(handler, event_type, transformed_message)
return continue_flag
return continue_flag, modified_message
async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
@@ -327,16 +331,18 @@ class EventsManager:
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
async def _dispatch_intercepting_handler(
async def _dispatch_intercepting_handler_task(
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
) -> bool:
) -> Tuple[bool, Optional[MaiMessages]]:
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
success, continue_processing, return_message, custom_result = await handler.execute(message)
success, continue_processing, return_message, custom_result, modified_message = await handler.execute(
message
)
if not success:
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
@@ -345,17 +351,17 @@ class EventsManager:
if self._history_enable_map[event_type] and custom_result:
self._events_result_history[event_type].append(custom_result)
return continue_processing
return continue_processing, modified_message
except KeyError:
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
return True
return True, None
except Exception as e:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
return True # 发生异常时默认不中断其他处理
return True, None # 发生异常时默认不中断其他处理
def _task_done_callback(
self,
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None]],
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
event_type: EventType | str,
):
"""任务完成回调"""
@@ -365,7 +371,7 @@ class EventsManager:
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
success, _, result, custom_result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截
if success:
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
else: