增加了event_handler修改内容的方法
This commit is contained in:
@@ -37,19 +37,19 @@ class BaseEventHandler(ABC):
|
||||
@abstractmethod
|
||||
async def execute(
|
||||
self, message: MaiMessages | None
|
||||
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]:
|
||||
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
|
||||
"""执行事件处理的抽象方法,子类必须实现
|
||||
Args:
|
||||
message (MaiMessages | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None
|
||||
Returns:
|
||||
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果)
|
||||
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果,可选的修改后消息)
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现 execute 方法")
|
||||
|
||||
@classmethod
|
||||
def get_handler_info(cls) -> "EventHandlerInfo":
|
||||
"""获取事件处理器的信息"""
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成S
|
||||
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
|
||||
if "." in name:
|
||||
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
@@ -7,6 +8,7 @@ from maim_message import Seg
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
|
||||
|
||||
# 组件类型枚举
|
||||
class ComponentType(Enum):
|
||||
"""组件类型枚举"""
|
||||
@@ -56,6 +58,7 @@ class EventType(Enum):
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
|
||||
ON_MESSAGE_PRE_PROCESS = "on_message_pre_process"
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
@@ -116,9 +119,9 @@ class ActionInfo(ComponentInfo):
|
||||
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
||||
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
||||
# 激活类型相关
|
||||
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
|
||||
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
|
||||
activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
|
||||
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
|
||||
activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
random_activation_probability: float = 0.0
|
||||
llm_judge_prompt: str = ""
|
||||
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
|
||||
@@ -154,7 +157,9 @@ class CommandInfo(ComponentInfo):
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
|
||||
default_factory=list
|
||||
) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -233,6 +238,15 @@ class PluginInfo:
|
||||
return [dep.get_pip_requirement() for dep in self.python_dependencies]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModifyFlag:
|
||||
modify_message_segments: bool = False
|
||||
modify_plain_text: bool = False
|
||||
modify_llm_prompt: bool = False
|
||||
modify_llm_response_content: bool = False
|
||||
modify_llm_response_reasoning: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaiMessages:
|
||||
"""MaiM插件消息"""
|
||||
@@ -263,31 +277,129 @@ class MaiMessages:
|
||||
|
||||
llm_response_content: Optional[str] = None
|
||||
"""LLM响应内容"""
|
||||
|
||||
|
||||
llm_response_reasoning: Optional[str] = None
|
||||
"""LLM响应推理内容"""
|
||||
|
||||
|
||||
llm_response_model: Optional[str] = None
|
||||
"""LLM响应模型名称"""
|
||||
|
||||
|
||||
llm_response_tool_call: Optional[List[ToolCall]] = None
|
||||
"""LLM使用的工具调用"""
|
||||
|
||||
|
||||
action_usage: Optional[List[str]] = None
|
||||
"""使用的Action"""
|
||||
|
||||
additional_data: Dict[Any, Any] = field(default_factory=dict)
|
||||
"""附加数据,可以存储额外信息"""
|
||||
|
||||
_modify_flags: ModifyFlag = field(default_factory=ModifyFlag)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.message_segments is None:
|
||||
self.message_segments = []
|
||||
|
||||
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
|
||||
"""
|
||||
修改消息段列表
|
||||
|
||||
Warning:
|
||||
在生成了plain_text的情况下调用此方法,可能会导致plain_text内容与消息段不一致
|
||||
|
||||
Args:
|
||||
new_segments (List[Seg]): 新的消息段列表
|
||||
"""
|
||||
if self.plain_text and not suppress_warning:
|
||||
warnings.warn(
|
||||
"修改消息段后,plain_text可能与消息段内容不一致,建议同时更新plain_text",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.message_segments = new_segments
|
||||
self._modify_flags.modify_message_segments = True
|
||||
|
||||
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False):
|
||||
"""
|
||||
修改LLM提示词
|
||||
|
||||
Warning:
|
||||
在没有生成llm_prompt的情况下调用此方法,可能会导致修改无效
|
||||
|
||||
Args:
|
||||
new_prompt (str): 新的提示词内容
|
||||
"""
|
||||
if self.llm_prompt is None and not suppress_warning:
|
||||
warnings.warn(
|
||||
"当前llm_prompt为空,此时调用方法可能导致修改无效",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.llm_prompt = new_prompt
|
||||
self._modify_flags.modify_llm_prompt = True
|
||||
|
||||
def modify_plain_text(self, new_text: str, suppress_warning: bool = False):
|
||||
"""
|
||||
修改生成的plain_text内容
|
||||
|
||||
Warning:
|
||||
在未生成plain_text的情况下调用此方法,可能会导致plain_text为空或者修改无效
|
||||
|
||||
Args:
|
||||
new_text (str): 新的纯文本内容
|
||||
"""
|
||||
if not self.plain_text and not suppress_warning:
|
||||
warnings.warn(
|
||||
"当前plain_text为空,此时调用方法可能导致修改无效",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.plain_text = new_text
|
||||
self._modify_flags.modify_plain_text = True
|
||||
|
||||
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False):
|
||||
"""
|
||||
修改生成的llm_response_content内容
|
||||
|
||||
Warning:
|
||||
在未生成llm_response_content的情况下调用此方法,可能会导致llm_response_content为空或者修改无效
|
||||
|
||||
Args:
|
||||
new_content (str): 新的LLM响应内容
|
||||
"""
|
||||
if not self.llm_response_content and not suppress_warning:
|
||||
warnings.warn(
|
||||
"当前llm_response_content为空,此时调用方法可能导致修改无效",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.llm_response_content = new_content
|
||||
self._modify_flags.modify_llm_response_content = True
|
||||
|
||||
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False):
|
||||
"""
|
||||
修改生成的llm_response_reasoning内容
|
||||
|
||||
Warning:
|
||||
在未生成llm_response_reasoning的情况下调用此方法,可能会导致llm_response_reasoning为空或者修改无效
|
||||
|
||||
Args:
|
||||
new_reasoning (str): 新的LLM响应推理内容
|
||||
"""
|
||||
if not self.llm_response_reasoning and not suppress_warning:
|
||||
warnings.warn(
|
||||
"当前llm_response_reasoning为空,此时调用方法可能导致修改无效",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.llm_response_reasoning = new_reasoning
|
||||
self._modify_flags.modify_llm_response_reasoning = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomEventHandlerResult:
|
||||
message: str = ""
|
||||
timestamp: float = 0.0
|
||||
extra_info: Optional[Dict] = None
|
||||
extra_info: Optional[Dict] = None
|
||||
|
||||
@@ -71,7 +71,7 @@ class EventsManager:
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> bool:
|
||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
||||
"""
|
||||
处理所有事件,根据事件类型分发给订阅的处理器。
|
||||
"""
|
||||
@@ -89,10 +89,10 @@ class EventsManager:
|
||||
# 2. 获取并遍历处理器
|
||||
handlers = self._events_subscribers.get(event_type, [])
|
||||
if not handlers:
|
||||
return True
|
||||
return True, None
|
||||
|
||||
current_stream_id = transformed_message.stream_id if transformed_message else None
|
||||
|
||||
modified_message: Optional[MaiMessages] = None
|
||||
for handler in handlers:
|
||||
# 3. 前置检查和配置加载
|
||||
if (
|
||||
@@ -107,15 +107,19 @@ class EventsManager:
|
||||
handler.set_plugin_config(plugin_config)
|
||||
|
||||
# 4. 根据类型分发任务
|
||||
if handler.intercept_message or event_type == EventType.ON_STOP: # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消
|
||||
if (
|
||||
handler.intercept_message or event_type == EventType.ON_STOP
|
||||
): # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消
|
||||
# 阻塞执行,并更新 continue_flag
|
||||
should_continue = await self._dispatch_intercepting_handler(handler, event_type, transformed_message)
|
||||
should_continue, modified_message = await self._dispatch_intercepting_handler_task(
|
||||
handler, event_type, modified_message or transformed_message
|
||||
)
|
||||
continue_flag = continue_flag and should_continue
|
||||
else:
|
||||
# 异步执行,不阻塞
|
||||
self._dispatch_handler_task(handler, event_type, transformed_message)
|
||||
|
||||
return continue_flag
|
||||
return continue_flag, modified_message
|
||||
|
||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
||||
@@ -327,16 +331,18 @@ class EventsManager:
|
||||
except Exception as e:
|
||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
|
||||
|
||||
async def _dispatch_intercepting_handler(
|
||||
async def _dispatch_intercepting_handler_task(
|
||||
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
|
||||
) -> bool:
|
||||
) -> Tuple[bool, Optional[MaiMessages]]:
|
||||
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
|
||||
if event_type == EventType.UNKNOWN:
|
||||
raise ValueError("未知事件类型")
|
||||
if event_type not in self._history_enable_map:
|
||||
raise ValueError(f"事件类型 {event_type} 未注册")
|
||||
try:
|
||||
success, continue_processing, return_message, custom_result = await handler.execute(message)
|
||||
success, continue_processing, return_message, custom_result, modified_message = await handler.execute(
|
||||
message
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
|
||||
@@ -345,17 +351,17 @@ class EventsManager:
|
||||
|
||||
if self._history_enable_map[event_type] and custom_result:
|
||||
self._events_result_history[event_type].append(custom_result)
|
||||
return continue_processing
|
||||
return continue_processing, modified_message
|
||||
except KeyError:
|
||||
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
|
||||
return True
|
||||
return True, None
|
||||
except Exception as e:
|
||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
||||
return True # 发生异常时默认不中断其他处理
|
||||
return True, None # 发生异常时默认不中断其他处理
|
||||
|
||||
def _task_done_callback(
|
||||
self,
|
||||
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None]],
|
||||
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
|
||||
event_type: EventType | str,
|
||||
):
|
||||
"""任务完成回调"""
|
||||
@@ -365,7 +371,7 @@ class EventsManager:
|
||||
if event_type not in self._history_enable_map:
|
||||
raise ValueError(f"事件类型 {event_type} 未注册")
|
||||
try:
|
||||
success, _, result, custom_result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
|
||||
success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截
|
||||
if success:
|
||||
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user