From 10b16947a5f9535c76ca5769e03697df4d9d3023 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Tue, 14 Oct 2025 12:08:21 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E4=BD=BF=E7=94=A8=E9=80=9A?= =?UTF-8?q?=E7=94=A8=E9=85=8D=E7=BD=AE=E5=85=B6=E4=BB=96=E5=B9=B3=E5=8F=B0?= =?UTF-8?q?=E7=9A=84id,=E4=B8=80=E5=A4=84=E6=97=A0=E5=85=B3=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../heart_flow/heartflow_message_processor.py | 40 +-- src/chat/replyer/group_generator.py | 2 +- src/chat/utils/utils.py | 281 +++++------------- src/config/official_configs.py | 3 + template/bot_config_template.toml | 5 +- 5 files changed, 83 insertions(+), 248 deletions(-) diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 4247d02c..822d05de 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -1,7 +1,7 @@ import re import traceback -from typing import Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.storage import MessageStorage @@ -17,34 +17,6 @@ if TYPE_CHECKING: logger = get_logger("chat") - -async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]: - """计算消息的兴趣度 - - Args: - message: 待处理的消息对象 - - Returns: - Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词) - """ - if message.is_picid or message.is_emoji: - return 0.0, [] - - is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message) - # 保留适配器/上游直接标记的提及信号,避免被覆盖 - if getattr(message, "is_mentioned", False): - is_mentioned = True - # interested_rate = 0.0 - keywords = [] - - message.interest_value = 1 - message.is_mentioned = is_mentioned - message.is_at = is_at - message.reply_probability_boost = reply_probability_boost - - return 1, keywords - - class HeartFCMessageReceiver: """心流处理器,负责处理接收到的消息并计算兴趣度""" @@ -70,12 +42,16 @@ class HeartFCMessageReceiver: userinfo = message.message_info.user_info chat = message.chat_stream - # 2. 兴趣度计算与更新 - _, keywords = await _calculate_interest(message) + # 2. 计算at信息 + is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message) + print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}") + message.is_mentioned = is_mentioned + message.is_at = is_at + message.reply_probability_boost = reply_probability_boost await self.storage.store_message(message, chat) - _heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore + await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore # 3. 日志记录 mes_name = chat.group_info.group_name if chat.group_info else "私聊" diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 88935da7..3a24f04f 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -529,7 +529,7 @@ class DefaultReplyer: show_actions=True, ) core_dialogue_prompt = f"""-------------------------------- -这是你和{sender}的对话,你们正在交流中: +这是上述中你和{sender}的对话摘要,内容从上面的对话中截取,便于你理解: {core_dialogue_prompt_str} -------------------------------- """ diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 3011c865..ce3eab08 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -30,26 +30,56 @@ def is_english_letter(char: str) -> bool: return "a" <= char.lower() <= "z" -def db_message_to_str(message_dict: dict) -> str: - logger.debug(f"message_dict: {message_dict}") - time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) - try: - name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}" - except Exception: - name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" - content = message_dict.get("processed_plain_text", "") - result = f"[{time_str}] {name}: {content}\n" - logger.debug(f"result: {result}") +def parse_platform_accounts(platforms: list[str]) -> dict[str, str]: + """解析 platforms 列表,返回平台到账号的映射 + + Args: + platforms: 格式为 ["platform:account"] 的列表,如 ["tg:123456789", "wx:wxid123"] + + Returns: + 字典,键为平台名,值为账号 + """ + result = {} + for platform_entry in platforms: + if ":" in platform_entry: + platform_name, account = platform_entry.split(":", 1) + result[platform_name.strip()] = account.strip() return result +def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str: + """根据当前平台获取对应的账号 + + Args: + platform: 当前消息的平台 + platform_accounts: 从 platforms 列表解析的平台账号映射 + qq_account: QQ 账号(兼容旧配置) + + Returns: + 当前平台对应的账号 + """ + if platform == "qq": + return qq_account + elif platform == "telegram": + # 优先使用 tg,其次使用 telegram + return platform_accounts.get("tg", "") or platform_accounts.get("telegram", "") + else: + # 其他平台直接使用平台名作为键 + return platform_accounts.get(platform, "") + + def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]: - """检查消息是否提到了机器人(多平台实现)""" + """检查消息是否提到了机器人(统一多平台实现)""" text = message.processed_plain_text or "" platform = getattr(message.message_info, "platform", "") or "" - qq_id = str(getattr(global_config.bot, "qq_account", "") or "") - tg_id = str(getattr(global_config.bot, "telegram_account", "") or "") - tg_uname = str(getattr(global_config.bot, "telegram_username", "") or "") + + # 获取各平台账号 + platforms_list = getattr(global_config.bot, "platforms", []) or [] + platform_accounts = parse_platform_accounts(platforms_list) + qq_account = str(getattr(global_config.bot, "qq_account", "") or "") + + # 获取当前平台对应的账号 + current_account = get_current_platform_account(platform, platform_accounts, qq_account) nickname = str(global_config.bot.nickname or "") alias_names = list(getattr(global_config.bot, "alias_names", []) or []) @@ -94,32 +124,30 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float is_at = True is_mentioned = True - # 4) 文本层面的 @ 检测(多平台) - # QQ: @ - if qq_id and re.search(rf"@<(.+?):{re.escape(qq_id)}>", text): - is_at = True - is_mentioned = True - # Telegram: @username - if platform == "telegram" and tg_uname: - if re.search(rf"@{re.escape(tg_uname)}(\b|$)", text, flags=re.IGNORECASE): - is_at = True - is_mentioned = True + # 4) 统一的 @ 检测逻辑 + if current_account and not is_at and not is_mentioned: + if platform == "qq": + # QQ 格式: @ + if re.search(rf"@<(.+?):{re.escape(current_account)}>", text): + is_at = True + is_mentioned = True + else: + # 其他平台格式: @username 或 @account + if re.search(rf"@{re.escape(current_account)}(\b|$)", text, flags=re.IGNORECASE): + is_at = True + is_mentioned = True - # 5) 回复机器人检测: - # a) 通用显示文本:包含 “(你)” 或 “(你)” 的回复格式 - if re.search(r"\[回复 .*?\(你\):", text) or re.search(r"\[回复 .*?(你):", text): - is_mentioned = True - # b) 兼容 ID 形式(QQ与Telegram) - if qq_id and ( - re.search(rf"\[回复 (.+?)\({re.escape(qq_id)}\):(.+?)\],说:", text) - or re.search(rf"\[回复<(.+?)(?=:{re.escape(qq_id)}>)\:{re.escape(qq_id)}>:(.+?)\],说:", text) - ): - is_mentioned = True - if tg_id and ( - re.search(rf"\[回复 (.+?)\({re.escape(tg_id)}\):(.+?)\],说:", text) - or re.search(rf"\[回复<(.+?)(?=:{re.escape(tg_id)}>)\:{re.escape(tg_id)}>:(.+?)\],说:", text) - ): - is_mentioned = True + # 5) 统一的回复检测逻辑 + if not is_mentioned: + # 通用回复格式:包含 "(你)" 或 "(你)" + if re.search(r"\[回复 .*?\(你\):", text) or re.search(r"\[回复 .*?(你):", text): + is_mentioned = True + # ID 形式的回复检测 + elif current_account: + if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\):(.+?)\],说:", text): + is_mentioned = True + elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text): + is_mentioned = True # 6) 名称/别名 提及(去除 @/回复标记后再匹配) if not is_mentioned and keywords: @@ -157,45 +185,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]] return embedding -def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list: - # 获取当前群聊记录内发言的人 - filter_query = {"chat_id": chat_stream_id} - sort_order = [("time", -1)] - recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit) - - if not recent_messages: - return [] - - who_chat_in_group = [] - for db_msg in recent_messages: - # user_info = UserInfo.from_dict( - # { - # "platform": msg_db_data["user_platform"], - # "user_id": msg_db_data["user_id"], - # "user_nickname": msg_db_data["user_nickname"], - # "user_cardname": msg_db_data.get("user_cardname", ""), - # } - # ) - # if ( - # (user_info.platform, user_info.user_id) != sender - # and user_info.user_id != global_config.bot.qq_account - # and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group - # and len(who_chat_in_group) < 5 - # ): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目 - # who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname)) - if ( - (db_msg.user_info.platform, db_msg.user_info.user_id) != sender - and db_msg.user_info.user_id != global_config.bot.qq_account - and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) - not in who_chat_in_group - and len(who_chat_in_group) < 5 - ): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目 - who_chat_in_group.append( - (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) - ) - - return who_chat_in_group - def split_into_sentences_w_remove_punctuation(text: str) -> list[str]: """将文本分割成句子,并根据概率合并 @@ -452,42 +441,6 @@ def calculate_typing_time( return total_time # 加上回车时间 -def cosine_similarity(v1, v2): - """计算余弦相似度""" - dot_product = np.dot(v1, v2) - norm1 = np.linalg.norm(v1) - norm2 = np.linalg.norm(v2) - return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2) - - -def text_to_vector(text): - """将文本转换为词频向量""" - # 分词 - words = jieba.lcut(text) - return Counter(words) - - -def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: - """使用简单的余弦相似度计算文本相似度""" - # 将输入文本转换为词频向量 - text_vector = text_to_vector(text) - - # 计算每个主题的相似度 - similarities = [] - for topic in topics: - topic_vector = text_to_vector(topic) - # 获取所有唯一词 - all_words = set(text_vector.keys()) | set(topic_vector.keys()) - # 构建向量 - v1 = [text_vector.get(word, 0) for word in all_words] - v2 = [topic_vector.get(word, 0) for word in all_words] - # 计算相似度 - similarity = cosine_similarity(v1, v2) - similarities.append((topic, similarity)) - - # 按相似度降序排序并返回前k个 - return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k] - def truncate_message(message: str, max_length=20) -> str: """截断消息,使其不超过指定长度""" @@ -565,47 +518,6 @@ def get_western_ratio(paragraph): return western_count / len(alnum_chars) -def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]: - """计算两个时间点之间的消息数量和文本总长度 - - Args: - start_time (float): 起始时间戳 (不包含) - end_time (float): 结束时间戳 (包含) - stream_id (str): 聊天流ID - - Returns: - tuple[int, int]: (消息数量, 文本总长度) - """ - count = 0 - total_length = 0 - - # 参数校验 (可选但推荐) - if start_time >= end_time: - # logger.debug(f"开始时间 {start_time} 大于或等于结束时间 {end_time},返回 0, 0") - return 0, 0 - if not stream_id: - logger.error("stream_id 不能为空") - return 0, 0 - - # 使用message_repository中的count_messages和find_messages函数 - - # 构建查询条件 - filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}} - - try: - # 先获取消息数量 - count = count_messages(filter_query) - - # 获取消息内容计算总长度 - messages = find_messages(message_filter=filter_query) - total_length = sum(len(msg.processed_plain_text or "") for msg in messages) - - return count, total_length - - except Exception as e: - logger.error(f"计算消息数量时发生意外错误: {e}") - return 0, 0 - def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str: # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch @@ -740,65 +652,6 @@ def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, Data return result -# def assign_message_ids_flexible( -# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False -# ) -> list: -# """ -# 为消息列表中的每个消息分配唯一的简短随机ID(增强版) - -# Args: -# messages: 消息列表 -# prefix: ID前缀,默认为"msg" -# id_length: ID的总长度(不包括前缀),默认为6 -# use_timestamp: 是否在ID中包含时间戳,默认为False - -# Returns: -# 包含 {'id': str, 'message': any} 格式的字典列表 -# """ -# result = [] -# used_ids = set() - -# for i, message in enumerate(messages): -# # 生成唯一的ID -# while True: -# if use_timestamp: -# # 使用时间戳的后几位 + 随机字符 -# timestamp_suffix = str(int(time.time() * 1000))[-3:] -# remaining_length = id_length - 3 -# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length)) -# message_id = f"{prefix}{timestamp_suffix}{random_chars}" -# else: -# # 使用索引 + 随机字符 -# index_str = str(i + 1) -# remaining_length = max(1, id_length - len(index_str)) -# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length)) -# message_id = f"{prefix}{index_str}{random_chars}" - -# if message_id not in used_ids: -# used_ids.add(message_id) -# break - -# result.append({"id": message_id, "message": message}) - -# return result - - -# 使用示例: -# messages = ["Hello", "World", "Test message"] -# -# # 基础版本 -# result1 = assign_message_ids(messages) -# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}] -# -# # 增强版本 - 自定义前缀和长度 -# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8) -# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}] -# -# # 增强版本 - 使用时间戳 -# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True) -# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}] - - def parse_keywords_string(keywords_input) -> list[str]: # sourcery skip: use-contextlib-suppress """ diff --git a/src/config/official_configs.py b/src/config/official_configs.py index df616a64..641b287a 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -27,6 +27,9 @@ class BotConfig(ConfigBase): nickname: str """昵称""" + + platforms: list[str] = field(default_factory=lambda: []) + """其他平台列表""" alias_names: list[str] = field(default_factory=lambda: []) """别名列表""" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index b88077ca..f51c5203 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.18.3" +version = "6.18.4" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -14,6 +14,9 @@ version = "6.18.3" [bot] platform = "qq" qq_account = "1145141919810" # 麦麦的QQ账号 + +platforms = ["wx:114514","xx:1919810"] # 麦麦的其他平台账号 + nickname = "麦麦" # 麦麦的昵称 alias_names = ["麦叠", "牢麦"] # 麦麦的别名