MessageUtils更新
This commit is contained in:
committed by
SengokuCola
parent
9e2afaf6bc
commit
9fbb733e0a
@@ -203,7 +203,7 @@ def load_message_via_file(monkeypatch):
|
||||
def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
    """Test stand-in for ``number_to_short_id``.

    ``original_id`` and ``salt`` are ignored; the result is simply ``length``
    copies of "X", so the generated short ID is predictable in assertions.
    """
    # Padding an empty string to ``length`` with "X" yields the fixed ID.
    return "".ljust(length, "X")
|
||||
|
||||
def dummy_is_bot_self(user_id: str, platform) -> bool:
    """Test stand-in for ``is_bot_self``.

    Takes the same ``(user_id, platform)`` arguments the call sites pass, but
    ignores ``platform``: only the fixed ID "bot_self" counts as the bot.
    """
    return user_id == "bot_self"
|
||||
|
||||
def load_utils_via_file(monkeypatch):
|
||||
@@ -212,6 +212,8 @@ def load_utils_via_file(monkeypatch):
|
||||
# Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用
|
||||
math_utils_mod = ModuleType("src.common.utils.math_utils")
|
||||
math_utils_mod.number_to_short_id = dummy_number_to_short_id
|
||||
math_utils_mod.TimestampMode = type("TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"})
|
||||
math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: "2024-01-01 12:00:00" # 返回固定的时间字符串
|
||||
monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod)
|
||||
|
||||
# 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析
|
||||
@@ -252,7 +254,7 @@ async def test_build_readable_message_basic(monkeypatch):
|
||||
user_info = UserInfo(user_id="u1", user_nickname="Alice")
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.raw_message = MessageSequence([TextComponent("Hello world")])
|
||||
text, mapping = await MessageUtils.build_readable_message([msg], anonymize=False, show_lineno=True)
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=False, show_lineno=True)
|
||||
assert "[1] Alice说:Hello world" in text
|
||||
assert mapping == {}
|
||||
|
||||
@@ -270,7 +272,7 @@ async def test_build_readable_message_anonymize(monkeypatch):
|
||||
user_info = UserInfo(user_id="u42", user_nickname="Bob")
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.raw_message = MessageSequence([TextComponent("Secret text")])
|
||||
text, mapping = await MessageUtils.build_readable_message([msg], anonymize=True, show_lineno=False)
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, show_lineno=False)
|
||||
# 根据实现,original_name 为 user_nickname,因此文本中应包含原始名称
|
||||
assert "XXXXXX说:" in text
|
||||
assert "u42" in mapping
|
||||
@@ -291,7 +293,7 @@ async def test_build_readable_message_replace_bot(monkeypatch):
|
||||
user_info = UserInfo(user_id="bot_self", user_nickname="SomeBot")
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.raw_message = MessageSequence([TextComponent("ping")])
|
||||
text, mapping = await MessageUtils.build_readable_message([msg], replace_bot_name=True, target_bot_name="MAIBot")
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], replace_bot_name=True, target_bot_name="MAIBot")
|
||||
assert "MAIBot说:ping" in text
|
||||
|
||||
|
||||
@@ -309,7 +311,7 @@ async def test_build_readable_message_image_extraction(monkeypatch):
|
||||
msg.session_id = "s_img"
|
||||
msg.raw_message = MessageSequence([img])
|
||||
msg.message_info = MessageInfo(UserInfo(user_id="ui_img", user_nickname="ImgUser"))
|
||||
text, mapping = await MessageUtils.build_readable_message([msg], extract_pictures=True)
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], extract_pictures=True)
|
||||
# 应包含图片描述占位
|
||||
assert "图片1" in text
|
||||
# mapping 不为空(匿名化未开启则为空)
|
||||
@@ -333,7 +335,7 @@ async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno(
|
||||
msg2.message_info = MessageInfo(UserInfo(user_id="bot_self", user_nickname="SomeBot"))
|
||||
msg1.raw_message = MessageSequence([TextComponent("Hi")])
|
||||
msg2.raw_message = MessageSequence([TextComponent("Hello")])
|
||||
text, mapping = await MessageUtils.build_readable_message(
|
||||
text, mapping, _ = await MessageUtils.build_readable_message(
|
||||
[msg1, msg2],
|
||||
anonymize=True,
|
||||
replace_bot_name=True,
|
||||
@@ -361,7 +363,7 @@ async def test_build_readable_message_with_at(monkeypatch):
|
||||
msg.session_id = "s_at"
|
||||
msg.raw_message = MessageSequence([at_comp])
|
||||
msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser"))
|
||||
text, mapping = await MessageUtils.build_readable_message([msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot")
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot")
|
||||
# 验证主消息和@组件中的用户信息都被处理
|
||||
assert "XXXXXX说:" in text # 主消息用户被匿名化
|
||||
assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化
|
||||
@@ -53,7 +53,7 @@ class ExpressionLearner:
|
||||
if not self._messages_cache:
|
||||
logger.debug("没有消息可供学习,跳过学习过程")
|
||||
return
|
||||
readable_message, _ = await MessageUtils.build_readable_message(
|
||||
readable_message, _, _ = await MessageUtils.build_readable_message(
|
||||
self._messages_cache,
|
||||
anonymize=True,
|
||||
show_lineno=True,
|
||||
|
||||
@@ -47,7 +47,7 @@ def number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
|
||||
return short_id
|
||||
|
||||
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode) -> str:
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode | str) -> str:
|
||||
"""将时间戳按照指定模式转换为人类可读的格式
|
||||
|
||||
Args:
|
||||
@@ -56,6 +56,11 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode)
|
||||
Returns:
|
||||
str: 转换后的时间字符串
|
||||
"""
|
||||
if isinstance(mode, str):
|
||||
if mode.upper() in TimestampMode.__members__:
|
||||
mode = TimestampMode[mode.upper()]
|
||||
else:
|
||||
raise ValueError(f"不支持的时间戳转换模式: {mode}")
|
||||
if mode in [TimestampMode.NORMAL, TimestampMode.NORMAL_NO_YMD]:
|
||||
return time.strftime(mode.value, time.localtime(timestamp))
|
||||
elif mode == TimestampMode.RELATIVE:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from enum import Enum
|
||||
from maim_message import MessageBase, Seg
|
||||
from typing import List, Tuple, Optional, Dict, TYPE_CHECKING
|
||||
from typing import List, Tuple, Optional, Dict, TYPE_CHECKING, Callable
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
@@ -136,7 +135,7 @@ class MessageUtils:
|
||||
|
||||
@staticmethod
|
||||
def store_message_to_db(message: "SessionMessage"):
|
||||
"""存储消息到数据库"""
|
||||
"""存储消息到数据库,此方法没有update机制"""
|
||||
from src.common.database.database import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
@@ -152,10 +151,12 @@ class MessageUtils:
|
||||
extract_pictures: bool = False,
|
||||
replace_bot_name: bool = False,
|
||||
target_bot_name: Optional[str] = None,
|
||||
timestamp_mode: Optional[TimestampMode] = None,
|
||||
timestamp_mode: Optional[TimestampMode | str] = None,
|
||||
show_message_id_prefix: bool = False,
|
||||
read_mark_time: Optional[float] = None,
|
||||
truncate_message: bool = False,
|
||||
) -> Tuple[str, Dict[str, Tuple[str, str]]]:
|
||||
truncate_func: Optional[Callable[[float], Tuple[Optional[int], str]]] = None,
|
||||
) -> Tuple[str, Dict[str, Tuple[str, str]], List[str]]:
|
||||
"""
|
||||
将消息构建为LLM可读的文本格式
|
||||
|
||||
@@ -168,17 +169,21 @@ class MessageUtils:
|
||||
target_bot_name (Optional[str]): 如果replace_bot_name为True,指定要替换的机器人名称,比如可以把机器人名称替换为“你”
|
||||
timestamp_mode (Optional[TimestampMode]): 时间戳显示模式,默认为None表示不显示时间戳
|
||||
show_message_id_prefix (bool): 是否在每条消息前显示消息ID前缀
|
||||
truncate_message (bool): 是否截断过长的消息文本,避免生成过长的输入给LLM
|
||||
truncate_message (bool): 是否启用消息文本截断功能,截断过长的消息文本
|
||||
truncate_func (Optional[Callable[[float], Tuple[Optional[int], str]]]) 截断函数,接受消息的百分位位置(0-1),返回一个元组(文本长度限制(可为None表不切割), 替换内容)
|
||||
Returns:
|
||||
return (Tuple[str, Dict[str, Tuple[str, str]]]): 构建后的消息文本,以及映射表(匿名ID, 原始名称)
|
||||
return (Tuple[str, Dict[str, Tuple[str, str]], List[str]]): 构建后的消息文本,映射表 {用户ID: (匿名ID, 原始名称)},消息编号列表
|
||||
"""
|
||||
msg_list: List["SessionMessage"] = messages
|
||||
user_id_mapping: Dict[str, Tuple[str, str]] = {} # user_id -> (匿名ID, 原始名称)
|
||||
message_ids: List[str] = [] # 存储消息编号的列表
|
||||
copied: bool = False # 标记是否已经复制过消息列表,避免不必要的复制开销
|
||||
img_map: Optional[Dict[str, Tuple[int, str]]] = None
|
||||
emoji_map: Optional[Dict[str, Tuple[int, str]]] = None
|
||||
if replace_bot_name and not target_bot_name:
|
||||
raise ValueError("当replace_bot_name为True时,必须指定target_bot_name参数")
|
||||
|
||||
# 匿名化和机器人名称处理
|
||||
if anonymize or replace_bot_name:
|
||||
user_id_mapping = {} # 利用弱引用直接传入并得到修改结果
|
||||
anonymous_messages: List["SessionMessage"] = []
|
||||
@@ -198,28 +203,99 @@ class MessageUtils:
|
||||
copied = True
|
||||
|
||||
processed_plain_texts: List[str] = []
|
||||
|
||||
# 将图片提取到内容最前面
|
||||
if extract_pictures:
|
||||
img_map = {} # binary_hash -> (图片ID, 描述信息)
|
||||
emoji_map = {} # binary_hash -> (表情ID, 描述信息)
|
||||
msg_list = [
|
||||
MessageUtils._extract_pictures_from_message(msg, img_map, emoji_map, copied) for msg in msg_list
|
||||
]
|
||||
processed_plain_texts.append("图片信息和表情信息:")
|
||||
processed_plain_texts.extend(f"[图片{img_id}: {desc}]" for img_id, desc in img_map.values())
|
||||
processed_plain_texts.append("") # 图片和表情之间添加一个换行,避免连在一起
|
||||
processed_plain_texts.extend(f"[表情{emoji_id}: {desc}]" for emoji_id, desc in emoji_map.values())
|
||||
processed_plain_texts.append("") # 表情和消息文本之间添加两个换行,避免连在一起
|
||||
processed_plain_texts.extend(("", "聊天记录信息:"))
|
||||
|
||||
lineno_counter = 1
|
||||
for msg in msg_list:
|
||||
msg_count = len(msg_list)
|
||||
read_mark_added_flag: bool = False # 标记是否已经添加过已读标签,确保只添加一次
|
||||
for i, msg in enumerate(msg_list):
|
||||
await msg.process()
|
||||
plain_text: str = msg.processed_plain_text # type: ignore
|
||||
usr_info = msg.message_info.user_info
|
||||
usr_name = usr_info.user_cardname or usr_info.user_nickname or "未知用户"
|
||||
header = f"[{lineno_counter}] {usr_name}说:" if show_lineno else f"{usr_name}说:"
|
||||
lineno_counter += 1
|
||||
if truncate_message: # 消息截断逻辑
|
||||
percentile = i / msg_count
|
||||
if not read_mark_time: # 没有已读标签
|
||||
plain_text = MessageUtils._truncate_message(
|
||||
percentile,
|
||||
plain_text,
|
||||
truncate_func or MessageUtils._default_truncate_func,
|
||||
)
|
||||
elif msg.timestamp.timestamp() < read_mark_time:
|
||||
plain_text = MessageUtils._truncate_message(
|
||||
percentile,
|
||||
plain_text,
|
||||
truncate_func or MessageUtils._default_truncate_func,
|
||||
)
|
||||
elif not read_mark_added_flag:
|
||||
read_mark_added_flag = True
|
||||
processed_plain_texts.append("\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n")
|
||||
header, message_id = MessageUtils._build_line_header(
|
||||
msg,
|
||||
show_message_id_prefix,
|
||||
i + 1,
|
||||
show_lineno,
|
||||
timestamp_mode,
|
||||
)
|
||||
if message_id is not None:
|
||||
message_ids.append(message_id)
|
||||
processed_plain_texts.append("".join([header, plain_text]))
|
||||
|
||||
return "\n".join(processed_plain_texts), user_id_mapping
|
||||
return "\n".join(processed_plain_texts), user_id_mapping, message_ids
|
||||
|
||||
@staticmethod
|
||||
def _build_line_header(
    message: "SessionMessage",
    show_message_id_prefix: bool,
    counter: int,
    show_lineno: bool,
    timestamp_mode: "Optional[str | TimestampMode]" = None,
) -> Tuple[str, Optional[str]]:
    """Build the prefix that precedes one message's text.

    The header is assembled, left to right and space separated, as
    ``[counter] [timestamp] [消息ID: mNNN] <name>说:`` where each bracketed
    part is present only when its option is enabled.

    Args:
        message: Message whose sender info (and timestamp, if requested) is used.
        show_message_id_prefix: When True, generate a message ID and include it.
        counter: 1-based position of the message; used as the line number and
            as the leading digits of the generated message ID.
        show_lineno: When True, prepend ``[counter]``.
        timestamp_mode: Mode passed to ``translate_timestamp_to_human_readable``;
            a falsy value disables the timestamp part.

    Returns:
        ``(header, message_id)``; ``message_id`` is None unless
        ``show_message_id_prefix`` is True.
    """
    sender = message.message_info.user_info
    display_name = sender.user_cardname or sender.user_nickname or "未知用户"

    # "m" + counter + a random two-digit suffix.
    message_id: Optional[str] = None
    if show_message_id_prefix:
        message_id = f"m{counter}{random.randint(10, 99)}"

    timestamp_text: Optional[str] = None
    if timestamp_mode:
        timestamp_text = translate_timestamp_to_human_readable(
            message.timestamp.timestamp(), mode=timestamp_mode
        )

    segments = []
    if show_lineno:
        segments.append(f"[{counter}]")
    if timestamp_mode:
        segments.append(f"[{timestamp_text}]")
    if message_id is not None:
        segments.append(f"[消息ID: {message_id}]")
    # Fullwidth colon: this is the header format build_readable_message
    # produced before this helper existed and the one its tests assert on.
    segments.append(f"{display_name}说:")
    return " ".join(segments), message_id
|
||||
|
||||
@staticmethod
|
||||
def _truncate_message(
    percentile: float, original_content: str, truncate_func: Callable[[float], Tuple[Optional[int], str]]
) -> str:
    """Shorten ``original_content`` according to ``truncate_func``.

    Args:
        percentile: Position of the message in the list, forwarded unchanged
            to ``truncate_func``.
        original_content: Message text to shorten.
        truncate_func: Returns ``(limit, replacement)`` for a percentile;
            ``limit`` may be None to request no truncation.

    Returns:
        The first ``limit`` characters followed by ``replacement`` when the
        text is longer than ``limit``; otherwise the text unchanged.
    """
    limit, replacement = truncate_func(percentile)
    # None, 0 and negative limits all mean "do not cut"; a negative value would
    # otherwise slice characters off the end of the text.
    if limit is None or limit <= 0:
        return original_content
    # Text that already fits must not receive the "too long" suffix.
    if len(original_content) <= limit:
        return original_content
    return f"{original_content[:limit]}{replacement}"
|
||||
|
||||
@staticmethod
|
||||
def _default_truncate_func(percentile: float) -> Tuple[int, str]:
    """Default truncation policy: map a list position to ``(limit, suffix)``.

    ``percentile`` is compared against fixed upper bounds; the first bound it
    is strictly below selects the character limit and the suffix appended to
    truncated text. Values of 0.8 and above use the last, shortest tier.

    NOTE(review): limits shrink as ``percentile`` grows. If callers pass
    messages oldest-first, the newest messages are cut hardest and get the
    "can't remember" suffix — confirm this ordering is intended.
    """
    # (exclusive upper bound, character limit, suffix)
    tiers = (
        (0.3, 400, "......(内容太长了)"),
        (0.5, 200, "......(内容太长了)"),
        (0.8, 100, "......(有点记不清了)"),
    )
    for upper_bound, char_limit, suffix in tiers:
        if percentile < upper_bound:
            return char_limit, suffix
    return 50, "......(记不清了)"
|
||||
|
||||
@staticmethod
|
||||
def _process_usr_info(
|
||||
@@ -232,11 +308,13 @@ class MessageUtils:
|
||||
):
|
||||
"""处理消息中的用户信息,进行匿名化显示"""
|
||||
new_message = message.deepcopy()
|
||||
platform = message.platform
|
||||
new_component_list = [
|
||||
MessageUtils._process_msg_component(
|
||||
component,
|
||||
anonymize_mapping,
|
||||
salt,
|
||||
platform,
|
||||
anonymize,
|
||||
replace_bot_name,
|
||||
target_bot_name,
|
||||
@@ -254,7 +332,7 @@ class MessageUtils:
|
||||
anonymous_name = anonymize_mapping[msg_usr_info.user_id][0]
|
||||
new_message.message_info.user_info.user_nickname = anonymous_name
|
||||
new_message.message_info.user_info.user_cardname = anonymous_name
|
||||
if replace_bot_name and target_bot_name and is_bot_self(msg_usr_info.user_id):
|
||||
if replace_bot_name and target_bot_name and is_bot_self(msg_usr_info.user_id, platform):
|
||||
new_message.message_info.user_info.user_nickname = target_bot_name
|
||||
new_message.message_info.user_info.user_cardname = target_bot_name
|
||||
return new_message
|
||||
@@ -264,6 +342,7 @@ class MessageUtils:
|
||||
component: StandardMessageComponents,
|
||||
anonymize_mapping: Dict[str, Tuple[str, str]],
|
||||
salt: str,
|
||||
platform: str,
|
||||
anonymize: bool,
|
||||
replace_bot_name: bool,
|
||||
target_bot_name: Optional[str] = None,
|
||||
@@ -274,6 +353,7 @@ class MessageUtils:
|
||||
component,
|
||||
anonymize_mapping,
|
||||
salt,
|
||||
platform,
|
||||
anonymize,
|
||||
replace_bot_name,
|
||||
target_bot_name,
|
||||
@@ -283,6 +363,7 @@ class MessageUtils:
|
||||
component,
|
||||
anonymize_mapping,
|
||||
salt,
|
||||
platform,
|
||||
anonymize,
|
||||
replace_bot_name,
|
||||
target_bot_name,
|
||||
@@ -292,6 +373,7 @@ class MessageUtils:
|
||||
component,
|
||||
anonymize_mapping,
|
||||
salt,
|
||||
platform,
|
||||
anonymize,
|
||||
replace_bot_name,
|
||||
target_bot_name,
|
||||
@@ -303,6 +385,7 @@ class MessageUtils:
|
||||
component: AtComponent,
|
||||
anonymize_mapping: Dict[str, Tuple[str, str]],
|
||||
salt: str,
|
||||
platform: str,
|
||||
anonymize: bool,
|
||||
replace_bot_name: bool,
|
||||
target_bot_name: Optional[str] = None,
|
||||
@@ -319,7 +402,7 @@ class MessageUtils:
|
||||
anonymous_name = anonymize_mapping[user_id][0]
|
||||
component.target_user_nickname = anonymous_name
|
||||
component.target_user_cardname = anonymous_name
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id):
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id, platform):
|
||||
component.target_user_nickname = target_bot_name
|
||||
component.target_user_cardname = target_bot_name
|
||||
return component
|
||||
@@ -329,6 +412,7 @@ class MessageUtils:
|
||||
component: ForwardNodeComponent,
|
||||
anonymize_mapping: Dict[str, Tuple[str, str]],
|
||||
salt: str,
|
||||
platform: str,
|
||||
anonymize: bool,
|
||||
replace_bot_name: bool,
|
||||
target_bot_name: Optional[str] = None,
|
||||
@@ -354,7 +438,7 @@ class MessageUtils:
|
||||
anonymous_name = anonymize_mapping[user_id][0]
|
||||
comp.user_nickname = anonymous_name
|
||||
comp.user_cardname = anonymous_name
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id):
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id, platform):
|
||||
comp.user_nickname = target_bot_name
|
||||
comp.user_cardname = target_bot_name
|
||||
comp.content = [ # 递归处理转发消息中的组件
|
||||
@@ -362,6 +446,7 @@ class MessageUtils:
|
||||
c,
|
||||
anonymize_mapping,
|
||||
salt,
|
||||
platform,
|
||||
anonymize,
|
||||
replace_bot_name,
|
||||
target_bot_name,
|
||||
@@ -375,6 +460,7 @@ class MessageUtils:
|
||||
component: ReplyComponent,
|
||||
anonymize_mapping: Dict[str, Tuple[str, str]],
|
||||
salt: str,
|
||||
platform: str,
|
||||
anonymize: bool,
|
||||
replace_bot_name: bool,
|
||||
target_bot_name: Optional[str] = None,
|
||||
@@ -391,7 +477,7 @@ class MessageUtils:
|
||||
anonymous_name = anonymize_mapping[user_id][0]
|
||||
component.target_message_sender_nickname = anonymous_name
|
||||
component.target_message_sender_cardname = anonymous_name
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id):
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id, platform):
|
||||
component.target_message_sender_nickname = target_bot_name
|
||||
component.target_message_sender_cardname = target_bot_name
|
||||
else:
|
||||
@@ -447,10 +533,10 @@ class MessageUtils:
|
||||
|
||||
|
||||
# TODO: 这个函数的实现非常临时,后续需要替换为更完善的实现,比如直接从配置文件中读取机器人自己的ID,或者通过API获取机器人自己的信息等
|
||||
def is_bot_self(user_id: str, platform: str) -> bool:
    """
    Return whether ``user_id`` on ``platform`` is the bot itself.

    Temporary placeholder: only the hard-coded pair
    ("bot_self", "test_platform") is recognised. It is meant to be replaced
    by a fuller implementation later.
    """
    return user_id == "bot_self" and platform == "test_platform"
|
||||
|
||||
Reference in New Issue
Block a user