diff --git a/pytests/utils_test/message_utils_test.py b/pytests/utils_test/message_utils_test.py index 97441437..60c32cbb 100644 --- a/pytests/utils_test/message_utils_test.py +++ b/pytests/utils_test/message_utils_test.py @@ -203,7 +203,7 @@ def load_message_via_file(monkeypatch): def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str: return "X" * length # 返回固定的字符串,长度由参数决定,模拟生成短ID的行为 -def dummy_is_bot_self(user_id: str) -> bool: +def dummy_is_bot_self(user_id: str, platform) -> bool: return user_id == "bot_self" def load_utils_via_file(monkeypatch): @@ -212,6 +212,8 @@ def load_utils_via_file(monkeypatch): # Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用 math_utils_mod = ModuleType("src.common.utils.math_utils") math_utils_mod.number_to_short_id = dummy_number_to_short_id + math_utils_mod.TimestampMode = type("TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"}) + math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: "2024-01-01 12:00:00" # 返回固定的时间字符串 monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod) # 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析 @@ -252,7 +254,7 @@ async def test_build_readable_message_basic(monkeypatch): user_info = UserInfo(user_id="u1", user_nickname="Alice") msg.message_info = MessageInfo(user_info=user_info) msg.raw_message = MessageSequence([TextComponent("Hello world")]) - text, mapping = await MessageUtils.build_readable_message([msg], anonymize=False, show_lineno=True) + text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=False, show_lineno=True) assert "[1] Alice说:Hello world" in text assert mapping == {} @@ -270,7 +272,7 @@ async def test_build_readable_message_anonymize(monkeypatch): user_info = UserInfo(user_id="u42", user_nickname="Bob") msg.message_info = MessageInfo(user_info=user_info) msg.raw_message = MessageSequence([TextComponent("Secret text")]) - text, mapping = await MessageUtils.build_readable_message([msg], anonymize=True, show_lineno=False) + text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, show_lineno=False) # 根据实现,original_name 为 user_nickname,因此文本中应包含原始名称 assert "XXXXXX说:" in text assert "u42" in mapping @@ -291,7 +293,7 @@ async def test_build_readable_message_replace_bot(monkeypatch): user_info = UserInfo(user_id="bot_self", user_nickname="SomeBot") msg.message_info = MessageInfo(user_info=user_info) msg.raw_message = MessageSequence([TextComponent("ping")]) - text, mapping = await MessageUtils.build_readable_message([msg], replace_bot_name=True, target_bot_name="MAIBot") + text, mapping, _ = await MessageUtils.build_readable_message([msg], replace_bot_name=True, target_bot_name="MAIBot") assert "MAIBot说:ping" in text @@ -309,7 +311,7 @@ async def test_build_readable_message_image_extraction(monkeypatch): msg.session_id = "s_img" msg.raw_message = MessageSequence([img]) msg.message_info = MessageInfo(UserInfo(user_id="ui_img", user_nickname="ImgUser")) - text, mapping = await MessageUtils.build_readable_message([msg], extract_pictures=True) + text, mapping, _ = await MessageUtils.build_readable_message([msg], extract_pictures=True) # 应包含图片描述占位 assert "图片1" in text # mapping 不为空(匿名化未开启则为空) @@ -333,7 +335,7 @@ async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno( msg2.message_info = MessageInfo(UserInfo(user_id="bot_self", user_nickname="SomeBot")) msg1.raw_message = MessageSequence([TextComponent("Hi")]) msg2.raw_message = MessageSequence([TextComponent("Hello")]) - text, mapping = await MessageUtils.build_readable_message( + text, mapping, _ = await MessageUtils.build_readable_message( [msg1, msg2], anonymize=True, replace_bot_name=True, @@ -361,7 +363,7 @@ async def test_build_readable_message_with_at(monkeypatch): msg.session_id = "s_at" msg.raw_message = MessageSequence([at_comp]) msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser")) - text, mapping = await MessageUtils.build_readable_message([msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot") + text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot") # 验证主消息和@组件中的用户信息都被处理 assert "XXXXXX说:" in text # 主消息用户被匿名化 assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化 \ No newline at end of file diff --git a/src/bw_learner/expression_learner.py b/src/bw_learner/expression_learner.py index 54e413ad..1f879426 100644 --- a/src/bw_learner/expression_learner.py +++ b/src/bw_learner/expression_learner.py @@ -53,7 +53,7 @@ class ExpressionLearner: if not self._messages_cache: logger.debug("没有消息可供学习,跳过学习过程") return - readable_message, _ = await MessageUtils.build_readable_message( + readable_message, _, _ = await MessageUtils.build_readable_message( self._messages_cache, anonymize=True, show_lineno=True, diff --git a/src/common/utils/math_utils.py b/src/common/utils/math_utils.py index 9b70296c..599946f2 100644 --- a/src/common/utils/math_utils.py +++ b/src/common/utils/math_utils.py @@ -47,7 +47,7 @@ def number_to_short_id(original_id: int, salt: str, length: int = 6) -> str: return short_id -def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode) -> str: +def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode | str) -> str: """将时间戳按照指定模式转换为人类可读的格式 Args: @@ -56,6 +56,11 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode) Returns: str: 转换后的时间字符串 """ + if isinstance(mode, str): + if mode.upper() in TimestampMode.__members__: + mode = TimestampMode[mode.upper()] + else: + raise ValueError(f"不支持的时间戳转换模式: {mode}") if mode in [TimestampMode.NORMAL, TimestampMode.NORMAL_NO_YMD]: return time.strftime(mode.value, time.localtime(timestamp)) elif mode == TimestampMode.RELATIVE: diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py index 03939e8a..e2ad4256 100644 --- a/src/common/utils/utils_message.py +++ b/src/common/utils/utils_message.py @@ -1,6 +1,5 @@ -from enum import Enum from maim_message import MessageBase, Seg -from typing import List, Tuple, Optional, Dict, TYPE_CHECKING +from typing import List, Tuple, Optional, Dict, TYPE_CHECKING, Callable import base64 import hashlib @@ -136,7 +135,7 @@ class MessageUtils: @staticmethod def store_message_to_db(message: "SessionMessage"): - """存储消息到数据库""" + """存储消息到数据库,此方法没有update机制""" from src.common.database.database import get_db_session with get_db_session() as session: @@ -152,10 +151,12 @@ class MessageUtils: extract_pictures: bool = False, replace_bot_name: bool = False, target_bot_name: Optional[str] = None, - timestamp_mode: Optional[TimestampMode] = None, + timestamp_mode: Optional[TimestampMode | str] = None, show_message_id_prefix: bool = False, + read_mark_time: Optional[float] = None, truncate_message: bool = False, - ) -> Tuple[str, Dict[str, Tuple[str, str]]]: + truncate_func: Optional[Callable[[float], Tuple[Optional[int], str]]] = None, + ) -> Tuple[str, Dict[str, Tuple[str, str]], List[str]]: """ 将消息构建为LLM可读的文本格式 @@ -168,17 +169,21 @@ class MessageUtils: target_bot_name (Optional[str]): 如果replace_bot_name为True,指定要替换的机器人名称,比如可以把机器人名称替换为“你” timestamp_mode (Optional[TimestampMode]): 时间戳显示模式,默认为None表示不显示时间戳 show_message_id_prefix (bool): 是否在每条消息前显示消息ID前缀 - truncate_message (bool): 是否截断过长的消息文本,避免生成过长的输入给LLM + truncate_message (bool): 是否启用消息文本截断功能,截断过长的消息文本 + truncate_func (Optional[Callable[[float], Tuple[Optional[int], str]]]) 截断函数,接受消息的百分位位置(0-1),返回一个元组(文本长度限制(可为None表不切割), 替换内容) Returns: - return (Tuple[str, Dict[str, Tuple[str, str]]]): 构建后的消息文本,以及映射表(匿名ID, 原始名称) + return (Tuple[str, Dict[str, Tuple[str, str]], List[str]]): 构建后的消息文本,映射表 {用户ID: (匿名ID, 原始名称)},消息编号列表 """ msg_list: List["SessionMessage"] = messages user_id_mapping: Dict[str, Tuple[str, str]] = {} # user_id -> (匿名ID, 原始名称) + message_ids: List[str] = [] # 存储消息编号的列表 copied: bool = False # 标记是否已经复制过消息列表,避免不必要的复制开销 img_map: Optional[Dict[str, Tuple[int, str]]] = None emoji_map: Optional[Dict[str, Tuple[int, str]]] = None if replace_bot_name and not target_bot_name: raise ValueError("当replace_bot_name为True时,必须指定target_bot_name参数") + + # 匿名化和机器人名称处理 if anonymize or replace_bot_name: user_id_mapping = {} # 利用弱引用直接传入并得到修改结果 anonymous_messages: List["SessionMessage"] = [] @@ -198,28 +203,99 @@ class MessageUtils: copied = True processed_plain_texts: List[str] = [] + + # 将图片提取到内容最前面 if extract_pictures: img_map = {} # binary_hash -> (图片ID, 描述信息) emoji_map = {} # binary_hash -> (表情ID, 描述信息) msg_list = [ MessageUtils._extract_pictures_from_message(msg, img_map, emoji_map, copied) for msg in msg_list ] + processed_plain_texts.append("图片信息和表情信息:") processed_plain_texts.extend(f"[图片{img_id}: {desc}]" for img_id, desc in img_map.values()) processed_plain_texts.append("") # 图片和表情之间添加一个换行,避免连在一起 processed_plain_texts.extend(f"[表情{emoji_id}: {desc}]" for emoji_id, desc in emoji_map.values()) - processed_plain_texts.append("") # 表情和消息文本之间添加两个换行,避免连在一起 + processed_plain_texts.extend(("", "聊天记录信息:")) - lineno_counter = 1 - for msg in msg_list: + msg_count = len(msg_list) + read_mark_added_flag: bool = False # 标记是否已经添加过已读标签,确保只添加一次 + for i, msg in enumerate(msg_list): await msg.process() plain_text: str = msg.processed_plain_text # type: ignore - usr_info = msg.message_info.user_info - usr_name = usr_info.user_cardname or usr_info.user_nickname or "未知用户" - header = f"[{lineno_counter}] {usr_name}说:" if show_lineno else f"{usr_name}说:" - lineno_counter += 1 + if truncate_message: # 消息截断逻辑 + percentile = i / msg_count + if not read_mark_time: # 没有已读标签 + plain_text = MessageUtils._truncate_message( + percentile, + plain_text, + truncate_func or MessageUtils._default_truncate_func, + ) + elif msg.timestamp.timestamp() < read_mark_time: + plain_text = MessageUtils._truncate_message( + percentile, + plain_text, + truncate_func or MessageUtils._default_truncate_func, + ) + elif not read_mark_added_flag: + read_mark_added_flag = True + processed_plain_texts.append("\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n") + header, message_id = MessageUtils._build_line_header( + msg, + show_message_id_prefix, + i + 1, + show_lineno, + timestamp_mode, + ) + if message_id is not None: + message_ids.append(message_id) processed_plain_texts.append("".join([header, plain_text])) - return "\n".join(processed_plain_texts), user_id_mapping + return "\n".join(processed_plain_texts), user_id_mapping, message_ids + + @staticmethod + def _build_line_header( + message: "SessionMessage", + show_message_id_prefix: bool, + counter: int, + show_lineno: bool, + timestamp_mode: Optional[str | TimestampMode] = None, + ) -> Tuple[str, Optional[str]]: + usr_info = message.message_info.user_info + usr_name = usr_info.user_cardname or usr_info.user_nickname or "未知用户" + header_parts = [f"{usr_name}说:"] + message_id = None + if show_message_id_prefix: + rand_id = f"{counter}{random.randint(10, 99)}" + message_id = f"m{rand_id}" + header_parts.insert(0, f"[消息ID: {message_id}]") + if timestamp_mode: + timestamp_str = translate_timestamp_to_human_readable(message.timestamp.timestamp(), mode=timestamp_mode) + header_parts.insert(0, f"[{timestamp_str}]") + if show_lineno: + header_parts.insert(0, f"[{counter}]") + return " ".join(header_parts), message_id + + @staticmethod + def _truncate_message( + percentile: float, original_content: str, truncate_func: Callable[[float], Tuple[Optional[int], str]] + ): + limit, replacement = truncate_func(percentile) + if limit: + return f"{original_content[:limit]}{replacement}" + else: + return original_content + + @staticmethod + def _default_truncate_func(percentile: float) -> Tuple[int, str]: + """默认的截断函数,根据消息在消息列表中的位置返回不同的截断长度和替换内容""" + if percentile < 0.3: + return 400, "......(内容太长了)" + elif percentile < 0.5: + return 200, "......(内容太长了)" + elif percentile < 0.8: + return 100, "......(有点记不清了)" + else: + return 50, "......(记不清了)" @staticmethod def _process_usr_info( @@ -232,11 +308,13 @@ class MessageUtils: ): """处理消息中的用户信息,进行匿名化显示""" new_message = message.deepcopy() + platform = message.platform new_component_list = [ MessageUtils._process_msg_component( component, anonymize_mapping, salt, + platform, anonymize, replace_bot_name, target_bot_name, @@ -254,7 +332,7 @@ class MessageUtils: anonymous_name = anonymize_mapping[msg_usr_info.user_id][0] new_message.message_info.user_info.user_nickname = anonymous_name new_message.message_info.user_info.user_cardname = anonymous_name - if replace_bot_name and target_bot_name and is_bot_self(msg_usr_info.user_id): + if replace_bot_name and target_bot_name and is_bot_self(msg_usr_info.user_id, platform): new_message.message_info.user_info.user_nickname = target_bot_name new_message.message_info.user_info.user_cardname = target_bot_name return new_message @@ -264,6 +342,7 @@ class MessageUtils: component: StandardMessageComponents, anonymize_mapping: Dict[str, Tuple[str, str]], salt: str, + platform: str, anonymize: bool, replace_bot_name: bool, target_bot_name: Optional[str] = None, @@ -274,6 +353,7 @@ class MessageUtils: component, anonymize_mapping, salt, + platform, anonymize, replace_bot_name, target_bot_name, @@ -283,6 +363,7 @@ class MessageUtils: component, anonymize_mapping, salt, + platform, anonymize, replace_bot_name, target_bot_name, @@ -292,6 +373,7 @@ class MessageUtils: component, anonymize_mapping, salt, + platform, anonymize, replace_bot_name, target_bot_name, @@ -303,6 +385,7 @@ class MessageUtils: component: AtComponent, anonymize_mapping: Dict[str, Tuple[str, str]], salt: str, + platform: str, anonymize: bool, replace_bot_name: bool, target_bot_name: Optional[str] = None, @@ -319,7 +402,7 @@ class MessageUtils: anonymous_name = anonymize_mapping[user_id][0] component.target_user_nickname = anonymous_name component.target_user_cardname = anonymous_name - if replace_bot_name and target_bot_name and is_bot_self(user_id): + if replace_bot_name and target_bot_name and is_bot_self(user_id, platform): component.target_user_nickname = target_bot_name component.target_user_cardname = target_bot_name return component @@ -329,6 +412,7 @@ class MessageUtils: component: ForwardNodeComponent, anonymize_mapping: Dict[str, Tuple[str, str]], salt: str, + platform: str, anonymize: bool, replace_bot_name: bool, target_bot_name: Optional[str] = None, @@ -354,7 +438,7 @@ class MessageUtils: anonymous_name = anonymize_mapping[user_id][0] comp.user_nickname = anonymous_name comp.user_cardname = anonymous_name - if replace_bot_name and target_bot_name and is_bot_self(user_id): + if replace_bot_name and target_bot_name and is_bot_self(user_id, platform): comp.user_nickname = target_bot_name comp.user_cardname = target_bot_name comp.content = [ # 递归处理转发消息中的组件 @@ -362,6 +446,7 @@ class MessageUtils: c, anonymize_mapping, salt, + platform, anonymize, replace_bot_name, target_bot_name, @@ -375,6 +460,7 @@ class MessageUtils: component: ReplyComponent, anonymize_mapping: Dict[str, Tuple[str, str]], salt: str, + platform: str, anonymize: bool, replace_bot_name: bool, target_bot_name: Optional[str] = None, @@ -391,7 +477,7 @@ class MessageUtils: anonymous_name = anonymize_mapping[user_id][0] component.target_message_sender_nickname = anonymous_name component.target_message_sender_cardname = anonymous_name - if replace_bot_name and target_bot_name and is_bot_self(user_id): + if replace_bot_name and target_bot_name and is_bot_self(user_id, platform): component.target_message_sender_nickname = target_bot_name component.target_message_sender_cardname = target_bot_name else: @@ -447,10 +533,10 @@ class MessageUtils: # TODO: 这个函数的实现非常临时,后续需要替换为更完善的实现,比如直接从配置文件中读取机器人自己的ID,或者通过API获取机器人自己的信息等 -def is_bot_self(user_id: str) -> bool: +def is_bot_self(user_id: str, platform: str) -> bool: """ 判断用户ID是否是机器人自己 临时方法,后续会替换为更完善的实现 """ - return user_id == "bot_self" + return user_id == "bot_self" and platform == "test_platform"