应要求提交上未完成的HFC, expression部分
This commit is contained in:
@@ -1,21 +1,23 @@
|
||||
import traceback
|
||||
import os
|
||||
import re
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from maim_message import UserInfo, Seg, GroupInfo
|
||||
from maim_message import MessageBase
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.chat.message_receive.message_old import MessageRecv
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
|
||||
from .message import SessionMessage
|
||||
from .chat_manager import chat_manager
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
@@ -25,50 +27,6 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
def _check_ban_words(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
|
||||
"""检查消息是否包含过滤词
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否包含过滤词
|
||||
"""
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = group_info.group_name if group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_ban_regex(text: str, userinfo: UserInfo, group_info: Optional[GroupInfo] = None) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否匹配过滤正则
|
||||
"""
|
||||
# 检查text是否为None或空字符串
|
||||
if text is None or not text:
|
||||
return False
|
||||
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = group_info.group_name if group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
self.bot = None # bot 实例引用
|
||||
@@ -100,9 +58,11 @@ class ChatBot:
|
||||
logger.error(f"创建PFC聊天失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _process_commands(self, message: MessageRecv):
|
||||
async def _process_commands(self, message: SessionMessage):
|
||||
# sourcery skip: use-named-expression
|
||||
"""使用新插件系统处理命令"""
|
||||
if not message.processed_plain_text:
|
||||
return False, None, True # 没有文本内容,继续处理消息
|
||||
try:
|
||||
text = message.processed_plain_text
|
||||
|
||||
@@ -112,11 +72,8 @@ class ChatBot:
|
||||
command_class, matched_groups, command_info = command_result
|
||||
plugin_name = command_info.plugin_name
|
||||
command_name = command_info.name
|
||||
if (
|
||||
message.chat_stream
|
||||
and message.chat_stream.stream_id
|
||||
and command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
||||
if message.session_id and command_name in global_announcement_manager.get_disabled_chat_commands(
|
||||
message.session_id
|
||||
):
|
||||
logger.info("用户禁用的命令,跳过处理")
|
||||
return False, None, True
|
||||
@@ -269,97 +226,115 @@ class ChatBot:
|
||||
)
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
message = MessageRecv(message_data)
|
||||
maim_raw_message = MessageBase.from_dict(message_data)
|
||||
message = SessionMessage.from_maim_message(maim_raw_message)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
EventType.ON_MESSAGE_PRE_PROCESS, message
|
||||
session_id = SessionUtils.calculate_session_id(
|
||||
message.platform,
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
group_id=group_info.group_id if group_info else None,
|
||||
)
|
||||
if not continue_flag:
|
||||
return
|
||||
if modified_message and modified_message._modify_flags.modify_message_segments:
|
||||
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||
|
||||
if await self.handle_notice_message(message):
|
||||
pass
|
||||
message.session_id = session_id # 正确初始化session_id
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
# TODO: 修复事件预处理部分
|
||||
# continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
# EventType.ON_MESSAGE_PRE_PROCESS, message
|
||||
# )
|
||||
# if not continue_flag:
|
||||
# return
|
||||
# if modified_message and modified_message._modify_flags.modify_message_segments:
|
||||
# message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||
|
||||
# TODO: notice消息处理
|
||||
# if await self.handle_notice_message(message):
|
||||
# pass
|
||||
|
||||
# 处理消息内容,识别表情包等二进制数据并转化为文本描述
|
||||
await message.process()
|
||||
|
||||
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(
|
||||
message.processed_plain_text,
|
||||
user_info, # type: ignore
|
||||
group_info,
|
||||
) or _check_ban_regex(
|
||||
message.raw_message, # type: ignore
|
||||
user_info, # type: ignore
|
||||
group_info,
|
||||
):
|
||||
text = message.processed_plain_text or ""
|
||||
is_banned, word = MessageUtils.check_ban_words(text)
|
||||
if is_banned:
|
||||
chat_name = group_info.group_name if group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{user_info.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return
|
||||
is_banned_regex, pattern = MessageUtils.check_ban_regex(text)
|
||||
if is_banned_regex:
|
||||
chat_name = group_info.group_name if group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{user_info.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
chat_manager.register_message(message)
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
platform = message.platform
|
||||
user_id = user_info.user_id
|
||||
group_id = group_info.group_id if group_info else None
|
||||
_ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# if await self.check_ban_content(message):
|
||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||
# return
|
||||
# message.update_chat_stream(chat)
|
||||
|
||||
# TODO: 在新命令系统完成后恢复这里
|
||||
# 命令处理 - 使用新插件系统检查并处理命令
|
||||
is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
# is_command, cmd_result, continue_process = await self._process_commands(message)
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
await MessageStorage.store_message(message, chat)
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
# if is_command and not continue_process:
|
||||
# await MessageStorage.store_message(message, chat)
|
||||
# logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
# return
|
||||
|
||||
continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
|
||||
if not continue_flag:
|
||||
return
|
||||
if modified_message and modified_message._modify_flags.modify_plain_text:
|
||||
message.processed_plain_text = modified_message.plain_text
|
||||
# continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
|
||||
# if not continue_flag:
|
||||
# return
|
||||
# if modified_message and modified_message._modify_flags.modify_plain_text:
|
||||
# message.processed_plain_text = modified_message.plain_text
|
||||
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
template_items = message.message_info.template_info.template_items
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
if isinstance(template_items, dict):
|
||||
for k in template_items.keys():
|
||||
await Prompt.create_async(template_items[k], k)
|
||||
logger.debug(f"注册{template_items[k]},{k}")
|
||||
else:
|
||||
template_group_name = None
|
||||
# # 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
# if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
# template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
# template_items = message.message_info.template_info.template_items
|
||||
# async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
# if isinstance(template_items, dict):
|
||||
# for k in template_items.keys():
|
||||
# await Prompt.create_async(template_items[k], k)
|
||||
# logger.debug(f"注册{template_items[k]},{k}")
|
||||
# else:
|
||||
# template_group_name = None
|
||||
|
||||
# async def preprocess():
|
||||
# # 根据聊天类型路由消息
|
||||
# if group_info is None:
|
||||
# # 私聊消息 -> PFC系统
|
||||
# logger.debug("[私聊]检测到私聊消息,路由到PFC系统")
|
||||
# await MessageStorage.store_message(message, chat)
|
||||
# await self._create_pfc_chat(message)
|
||||
# else:
|
||||
# # 群聊消息 -> HeartFlow系统
|
||||
# logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
||||
# await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
# if template_group_name:
|
||||
# async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
# await preprocess()
|
||||
# else:
|
||||
# await preprocess()
|
||||
async def preprocess():
|
||||
# 根据聊天类型路由消息
|
||||
if group_info is None:
|
||||
# 私聊消息 -> PFC系统
|
||||
logger.debug("[私聊]检测到私聊消息,路由到PFC系统")
|
||||
await MessageStorage.store_message(message, chat)
|
||||
MessageUtils.store_message_to_db(message) # 存储消息到数据库
|
||||
await self._create_pfc_chat(message)
|
||||
else:
|
||||
# 群聊消息 -> HeartFlow系统
|
||||
logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
if template_group_name:
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
await preprocess()
|
||||
else:
|
||||
await preprocess()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
Reference in New Issue
Block a user