重构整个插件系统,尝试恢复可启动性,新增插件系统maibot-plugin-sdk依赖
This commit is contained in:
@@ -19,9 +19,10 @@ from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.message_recorder import extract_and_distribute_messages
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.core.types import ActionInfo, EventType
|
||||
from src.core.event_bus import event_bus
|
||||
from src.chat.event_helpers import build_event_message
|
||||
from src.services import generator_service as generator_api, send_service as send_api, message_service as message_api, database_service as database_api
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
@@ -315,8 +316,9 @@ class BrainChatting:
|
||||
message_id_list=message_id_list,
|
||||
prompt_key="brain_planner",
|
||||
)
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||
_event_msg = build_event_message(EventType.ON_PLAN, llm_prompt=prompt_info[0], stream_id=self.chat_stream.stream_id)
|
||||
continue_flag, modified_message = await event_bus.emit(
|
||||
EventType.ON_PLAN, _event_msg
|
||||
)
|
||||
if not continue_flag:
|
||||
return False
|
||||
|
||||
@@ -23,8 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.core.component_registry import component_registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
||||
|
||||
169
src/chat/event_helpers.py
Normal file
169
src/chat/event_helpers.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
事件消息构建工具
|
||||
|
||||
将 chat 层的消息对象 (SessionMessage / MessageSending) 转换为
|
||||
核心事件系统使用的 MaiMessages,供调用 event_bus.emit() 前使用。
|
||||
"""
|
||||
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.core.types import EventType, MaiMessages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import MessageSending, SessionMessage
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
|
||||
logger = get_logger("event_helpers")
|
||||
|
||||
|
||||
def build_event_message(
|
||||
event_type: EventType | str,
|
||||
message: Optional["SessionMessage | MessageSending | MaiMessages"] = None,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> Optional[MaiMessages]:
|
||||
"""根据事件类型和输入,准备和转换消息对象。
|
||||
|
||||
迁移自 events_manager._prepare_message,保持相同的行为。
|
||||
"""
|
||||
if isinstance(message, MaiMessages):
|
||||
return message.deepcopy()
|
||||
|
||||
if message:
|
||||
return _transform_event_message(message, llm_prompt, llm_response)
|
||||
|
||||
if event_type not in (EventType.ON_START, EventType.ON_STOP):
|
||||
assert stream_id, "如果没有消息,必须为非启动/关闭事件提供流ID"
|
||||
if event_type in (EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM):
|
||||
return _build_message_from_stream(stream_id, llm_prompt, llm_response)
|
||||
else:
|
||||
return _build_message_without_raw(stream_id, llm_prompt, llm_response, action_usage)
|
||||
|
||||
return None # ON_START / ON_STOP 没有消息体
|
||||
|
||||
|
||||
def _transform_event_message(
|
||||
message: "SessionMessage | MessageSending",
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
) -> MaiMessages:
|
||||
"""将 SessionMessage / MessageSending 转换为 MaiMessages。"""
|
||||
from maim_message import Seg
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
|
||||
transformed = MaiMessages(
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response_content=llm_response.content if llm_response else None,
|
||||
llm_response_reasoning=llm_response.reasoning if llm_response else None,
|
||||
llm_response_model=llm_response.model if llm_response else None,
|
||||
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
|
||||
raw_message=message.processed_plain_text or "",
|
||||
additional_data={},
|
||||
)
|
||||
|
||||
# 消息段处理
|
||||
if isinstance(message, MessageSending):
|
||||
if message.message_segment.type == "seglist":
|
||||
transformed.message_segments = list(message.message_segment.data) # type: ignore
|
||||
else:
|
||||
transformed.message_segments = [message.message_segment]
|
||||
else:
|
||||
transformed.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
|
||||
|
||||
# stream_id
|
||||
transformed.stream_id = message.session_id if hasattr(message, "session_id") else ""
|
||||
|
||||
# 处理后文本
|
||||
transformed.plain_text = message.processed_plain_text
|
||||
|
||||
# 基本信息
|
||||
if isinstance(message, MessageSending):
|
||||
transformed.message_base_info["platform"] = message.platform
|
||||
if message.session.group_id:
|
||||
transformed.is_group_message = True
|
||||
group_name = ""
|
||||
if (
|
||||
message.session.context
|
||||
and message.session.context.message
|
||||
and message.session.context.message.message_info.group_info
|
||||
):
|
||||
group_name = message.session.context.message.message_info.group_info.group_name
|
||||
transformed.message_base_info.update(
|
||||
{
|
||||
"group_id": message.session.group_id,
|
||||
"group_name": group_name,
|
||||
}
|
||||
)
|
||||
transformed.message_base_info.update(
|
||||
{
|
||||
"user_id": message.bot_user_info.user_id,
|
||||
"user_cardname": message.bot_user_info.user_cardname,
|
||||
"user_nickname": message.bot_user_info.user_nickname,
|
||||
}
|
||||
)
|
||||
if not transformed.is_group_message:
|
||||
transformed.is_private_message = True
|
||||
elif hasattr(message, "message_info") and message.message_info:
|
||||
if message.platform:
|
||||
transformed.message_base_info["platform"] = message.platform
|
||||
if message.message_info.group_info:
|
||||
transformed.is_group_message = True
|
||||
transformed.message_base_info.update(
|
||||
{
|
||||
"group_id": message.message_info.group_info.group_id,
|
||||
"group_name": message.message_info.group_info.group_name,
|
||||
}
|
||||
)
|
||||
if message.message_info.user_info:
|
||||
if not transformed.is_group_message:
|
||||
transformed.is_private_message = True
|
||||
transformed.message_base_info.update(
|
||||
{
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_cardname": message.message_info.user_info.user_cardname,
|
||||
"user_nickname": message.message_info.user_info.user_nickname,
|
||||
}
|
||||
)
|
||||
|
||||
return transformed
|
||||
|
||||
|
||||
def _build_message_from_stream(
|
||||
stream_id: str,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
) -> MaiMessages:
|
||||
"""从 stream_id 查找会话消息并转换。"""
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
session = chat_manager.get_session_by_session_id(stream_id)
|
||||
assert session, f"未找到流ID为 {stream_id} 的会话"
|
||||
return _transform_event_message(session.context.message, llm_prompt, llm_response)
|
||||
|
||||
|
||||
def _build_message_without_raw(
|
||||
stream_id: str,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> MaiMessages:
|
||||
"""没有原始消息对象时,从 stream_id 构建最小 MaiMessages。"""
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
session = chat_manager.get_session_by_session_id(stream_id)
|
||||
assert session, f"未找到流ID为 {stream_id} 的会话"
|
||||
return MaiMessages(
|
||||
stream_id=stream_id,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response_content=llm_response.content if llm_response else None,
|
||||
llm_response_reasoning=llm_response.reasoning if llm_response else None,
|
||||
llm_response_model=llm_response.model if llm_response else None,
|
||||
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
|
||||
is_group_message=session.is_group_session,
|
||||
is_private_message=not session.is_group_session,
|
||||
action_usage=action_usage,
|
||||
additional_data={"response_is_processed": True},
|
||||
)
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.services import send_service as send_api
|
||||
|
||||
from src.common.message_repository import count_messages
|
||||
|
||||
|
||||
@@ -11,8 +11,9 @@ from src.common.utils.utils_session import SessionUtils
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.core.component_registry import component_registry
|
||||
from src.core.types import EventType
|
||||
|
||||
from .message import SessionMessage
|
||||
from .chat_manager import chat_manager
|
||||
@@ -65,10 +66,10 @@ class ChatBot:
|
||||
try:
|
||||
text = message.processed_plain_text
|
||||
|
||||
# 使用新的组件注册中心查找命令
|
||||
# 使用核心组件注册表查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_class, matched_groups, command_info = command_result
|
||||
command_executor, matched_groups, command_info = command_result
|
||||
plugin_name = command_info.plugin_name
|
||||
command_name = command_info.name
|
||||
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
|
||||
@@ -82,20 +83,20 @@ class ChatBot:
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
|
||||
# 创建命令实例
|
||||
command_instance: BaseCommand = command_class(message, plugin_config)
|
||||
command_instance.set_matched_groups(matched_groups)
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
success, response, intercept_message_level = await command_instance.execute()
|
||||
# 调用命令执行器
|
||||
success, response, intercept_message_level = await command_executor(
|
||||
message=message,
|
||||
plugin_config=plugin_config,
|
||||
matched_groups=matched_groups,
|
||||
)
|
||||
message.intercept_message_level = intercept_message_level
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_class.__name__} (拦截等级: {intercept_message_level})")
|
||||
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
|
||||
logger.warning(f"命令执行失败: {command_name} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return (
|
||||
@@ -105,14 +106,9 @@ class ChatBot:
|
||||
) # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
logger.error(f"执行命令时出错: {command_name} - {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await command_instance.send_text(f"命令执行出错: {str(e)}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
|
||||
|
||||
@@ -318,11 +318,13 @@ class UniversalMessageSender:
|
||||
message.build_reply()
|
||||
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
|
||||
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from src.core.event_bus import event_bus
|
||||
from src.chat.event_helpers import build_event_message
|
||||
from src.core.types import EventType
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id
|
||||
_event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id)
|
||||
continue_flag, modified_message = await event_bus.emit(
|
||||
EventType.POST_SEND_PRE_PROCESS, _event_msg
|
||||
)
|
||||
if not continue_flag:
|
||||
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
|
||||
@@ -336,8 +338,9 @@ class UniversalMessageSender:
|
||||
|
||||
await message.process()
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.POST_SEND, message=message, stream_id=chat_id
|
||||
_event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id)
|
||||
continue_flag, modified_message = await event_bus.emit(
|
||||
EventType.POST_SEND, _event_msg
|
||||
)
|
||||
if not continue_flag:
|
||||
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
|
||||
@@ -360,8 +363,9 @@ class UniversalMessageSender:
|
||||
if not sent_msg:
|
||||
return False
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.AFTER_SEND, message=message, stream_id=chat_id
|
||||
_event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id)
|
||||
continue_flag, modified_message = await event_bus.emit(
|
||||
EventType.AFTER_SEND, _event_msg
|
||||
)
|
||||
if not continue_flag:
|
||||
logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")
|
||||
|
||||
@@ -1,20 +1,34 @@
|
||||
from typing import Dict, Optional, Type
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.core.component_registry import component_registry, ActionExecutor
|
||||
from src.core.types import ActionInfo, ComponentType
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
class ActionHandle:
|
||||
"""Action 执行句柄
|
||||
|
||||
不依赖任何插件基类,内部持有 executor (async callable) 和绑定参数。
|
||||
brain_chat 调用 ``await handle.execute()`` 即可。
|
||||
"""
|
||||
|
||||
def __init__(self, executor: ActionExecutor, **kwargs):
|
||||
self._executor = executor
|
||||
self._kwargs = kwargs
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
return await self._executor(**self._kwargs)
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""
|
||||
动作管理器,用于管理各种类型的动作
|
||||
|
||||
现在统一使用新插件系统,简化了原有的新旧兼容逻辑。
|
||||
使用核心组件注册表的 executor-based 模式。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -39,9 +53,9 @@ class ActionManager:
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: Optional[DatabaseMessages] = None,
|
||||
) -> Optional[BaseAction]:
|
||||
) -> Optional[ActionHandle]:
|
||||
"""
|
||||
创建动作处理器实例
|
||||
创建动作执行句柄
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
@@ -52,30 +66,26 @@ class ActionManager:
|
||||
chat_stream: 聊天流
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
action_message: 动作消息记录
|
||||
|
||||
Returns:
|
||||
Optional[BaseAction]: 创建的动作处理器实例,如果动作名称未注册则返回None
|
||||
Optional[ActionHandle]: 执行句柄,如果动作未注册则返回 None
|
||||
"""
|
||||
try:
|
||||
# 获取组件类 - 明确指定查询Action类型
|
||||
component_class: Type[BaseAction] = component_registry.get_component_class(
|
||||
action_name, ComponentType.ACTION
|
||||
) # type: ignore
|
||||
if not component_class:
|
||||
executor = component_registry.get_action_executor(action_name)
|
||||
if not executor:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||
return None
|
||||
|
||||
# 获取组件信息
|
||||
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
|
||||
if not component_info:
|
||||
info = component_registry.get_action_info(action_name)
|
||||
if not info:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||
return None
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
||||
plugin_config = component_registry.get_plugin_config(info.plugin_name) or {}
|
||||
|
||||
# 创建动作实例
|
||||
instance = component_class(
|
||||
handle = ActionHandle(
|
||||
executor,
|
||||
action_data=action_data,
|
||||
action_reasoning=action_reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
@@ -87,11 +97,11 @@ class ActionManager:
|
||||
action_message=action_message,
|
||||
)
|
||||
|
||||
logger.debug(f"创建Action实例成功: {action_name}")
|
||||
return instance
|
||||
logger.debug(f"创建Action执行句柄成功: {action_name}")
|
||||
return handle
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建Action实例失败 {action_name}: {e}")
|
||||
logger.error(f"创建Action执行句柄失败 {action_name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -7,8 +7,8 @@ from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.core.types import ActionActivationType, ActionInfo
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
@@ -23,9 +23,9 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.core.types import ActionActivationType, ActionInfo, ComponentType
|
||||
from src.core.component_registry import component_registry
|
||||
from src.services.message_service import translate_pid_to_description
|
||||
from src.person_info.person_info import Person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -27,12 +27,12 @@ from src.chat.utils.chat_message_builder import (
|
||||
replace_user_references,
|
||||
)
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.services.message_service import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.core.types import ActionInfo, EventType
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||
@@ -56,7 +56,7 @@ class DefaultReplyer:
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
from src.chat.tool_executor import ToolExecutor
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
@@ -149,11 +149,13 @@ class DefaultReplyer:
|
||||
except Exception:
|
||||
logger.exception("记录reply日志失败")
|
||||
return False, llm_response
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.core.event_bus import event_bus
|
||||
from src.chat.event_helpers import build_event_message
|
||||
|
||||
if not from_plugin:
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
|
||||
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
|
||||
continue_flag, modified_message = await event_bus.emit(
|
||||
EventType.POST_LLM, _event_msg
|
||||
)
|
||||
if not continue_flag:
|
||||
raise UserWarning("插件于请求前中断了内容生成")
|
||||
@@ -217,8 +219,9 @@ class DefaultReplyer:
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("记录reply日志失败")
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
|
||||
_event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
|
||||
continue_flag, modified_message = await event_bus.emit(
|
||||
EventType.AFTER_LLM, _event_msg
|
||||
)
|
||||
if not from_plugin and not continue_flag:
|
||||
raise UserWarning("插件于请求后取消了内容生成")
|
||||
|
||||
@@ -28,12 +28,12 @@ from src.chat.utils.chat_message_builder import (
|
||||
replace_user_references,
|
||||
)
|
||||
from src.bw_learner.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.services.message_service import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
from src.person_info.person_info import Person, is_person_known
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.core.types import ActionInfo, EventType
|
||||
from src.services import llm_service as llm_api
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||
from src.bw_learner.jargon_explainer import explain_jargon_in_context
|
||||
|
||||
@@ -55,7 +55,7 @@ class PrivateReplyer:
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
# self.memory_activator = MemoryActivator()
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
from src.chat.tool_executor import ToolExecutor
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
@@ -114,11 +114,13 @@ class PrivateReplyer:
|
||||
if not prompt:
|
||||
logger.warning("构建prompt失败,跳过回复生成")
|
||||
return False, llm_response
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.core.event_bus import event_bus
|
||||
from src.chat.event_helpers import build_event_message
|
||||
|
||||
if not from_plugin:
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
|
||||
_event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id)
|
||||
continue_flag, modified_message = await event_bus.emit(
|
||||
EventType.POST_LLM, _event_msg
|
||||
)
|
||||
if not continue_flag:
|
||||
raise UserWarning("插件于请求前中断了内容生成")
|
||||
@@ -138,8 +140,9 @@ class PrivateReplyer:
|
||||
llm_response.reasoning = reasoning_content
|
||||
llm_response.model = model_name
|
||||
llm_response.tool_calls = tool_call
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
|
||||
_event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id)
|
||||
continue_flag, modified_message = await event_bus.emit(
|
||||
EventType.AFTER_LLM, _event_msg
|
||||
)
|
||||
if not from_plugin and not continue_flag:
|
||||
raise UserWarning("插件于请求后取消了内容生成")
|
||||
|
||||
250
src/chat/tool_executor.py
Normal file
250
src/chat/tool_executor.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
工具执行器
|
||||
|
||||
独立的工具执行组件,可以直接输入聊天消息内容,
|
||||
自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
|
||||
从 src.plugin_system.core.tool_use 迁移,使用新的核心组件注册表。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.core.component_registry import component_registry
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self.tool_cache: Dict[str, dict] = {}
|
||||
|
||||
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||
"""从聊天消息执行工具"""
|
||||
|
||||
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
||||
if cached_result := self._get_from_cache(cache_key):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||
if not return_details:
|
||||
return cached_result, [], ""
|
||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||
return cached_result, used_tools, ""
|
||||
|
||||
tools = self._get_tool_definitions()
|
||||
if not tools:
|
||||
logger.debug(f"{self.log_prefix}没有可用工具,直接返回空内容")
|
||||
return [], [], ""
|
||||
|
||||
prompt_template = prompt_manager.get_prompt("tool_executor")
|
||||
prompt_template.add_context("target_message", target_message)
|
||||
prompt_template.add_context("chat_history", chat_history)
|
||||
prompt_template.add_context("sender", sender)
|
||||
prompt_template.add_context("bot_name", global_config.bot.nickname)
|
||||
prompt_template.add_context("time_now", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
prompt = await prompt_manager.render_prompt(prompt_template)
|
||||
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools, raise_when_empty=False
|
||||
)
|
||||
|
||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||
|
||||
if tool_results:
|
||||
self._set_cache(cache_key, tool_results)
|
||||
|
||||
if used_tools:
|
||||
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
|
||||
|
||||
if return_details:
|
||||
return tool_results, used_tools, prompt
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""获取 LLM 可用的工具定义列表"""
|
||||
all_tools = component_registry.get_llm_available_tools()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [info.get_llm_definition() for name, info in all_tools.items() if name not in user_disabled_tools]
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""执行工具调用列表"""
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
used_tools: List[str] = []
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return [], []
|
||||
|
||||
func_names = [call.func_name for call in tool_calls if call.func_name]
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.func_name
|
||||
try:
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"tool_exec_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, (str, list, tuple)):
|
||||
tool_info["content"] = str(content)
|
||||
content_check = tool_info["content"]
|
||||
if (isinstance(content_check, str) and not content_check.strip()) or (
|
||||
isinstance(content_check, (list, tuple)) and len(content_check) == 0
|
||||
):
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}无有效内容,跳过展示")
|
||||
continue
|
||||
|
||||
tool_results.append(tool_info)
|
||||
used_tools.append(tool_name)
|
||||
preview = str(content)[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {str(e)}",
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(error_info)
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
|
||||
"""执行单个工具调用"""
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
function_args["llm_called"] = True
|
||||
|
||||
executor = component_registry.get_tool_executor(function_name)
|
||||
if not executor:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
result = await executor(function_args)
|
||||
if result:
|
||||
return {
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": "function",
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
|
||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||
"""直接执行指定工具"""
|
||||
try:
|
||||
tool_call = ToolCall(
|
||||
call_id=f"direct_tool_{time.time()}",
|
||||
func_name=tool_name,
|
||||
args=tool_args,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"direct_tool_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
||||
return tool_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
# === 缓存方法 ===
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
content = f"{target_message}_{chat_history}_{sender}"
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
|
||||
if not self.enable_cache or cache_key not in self.tool_cache:
|
||||
return None
|
||||
cache_item = self.tool_cache[cache_key]
|
||||
if cache_item["ttl"] <= 0:
|
||||
del self.tool_cache[cache_key]
|
||||
return None
|
||||
cache_item["ttl"] -= 1
|
||||
return cache_item["result"]
|
||||
|
||||
def _set_cache(self, cache_key: str, result: List[Dict]):
|
||||
if not self.enable_cache:
|
||||
return
|
||||
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
|
||||
|
||||
def _cleanup_expired_cache(self):
|
||||
if not self.enable_cache:
|
||||
return
|
||||
expired = [k for k, v in self.tool_cache.items() if v["ttl"] <= 0]
|
||||
for key in expired:
|
||||
del self.tool_cache[key]
|
||||
|
||||
def clear_cache(self):
|
||||
if self.enable_cache:
|
||||
self.tool_cache.clear()
|
||||
|
||||
def get_cache_status(self) -> Dict:
|
||||
if not self.enable_cache:
|
||||
return {"enabled": False, "cache_count": 0}
|
||||
self._cleanup_expired_cache()
|
||||
ttl_distribution: Dict[int, int] = {}
|
||||
for item in self.tool_cache.values():
|
||||
ttl = item["ttl"]
|
||||
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
|
||||
return {
|
||||
"enabled": True,
|
||||
"cache_count": len(self.tool_cache),
|
||||
"cache_ttl": self.cache_ttl,
|
||||
"ttl_distribution": ttl_distribution,
|
||||
}
|
||||
|
||||
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
||||
if enable_cache is not None:
|
||||
self.enable_cache = enable_cache
|
||||
if cache_ttl > 0:
|
||||
self.cache_ttl = cache_ttl
|
||||
Reference in New Issue
Block a user