feat:使用通用配置其他平台的id,一处无关代码

This commit is contained in:
SengokuCola
2025-10-14 12:08:21 +08:00
parent ba4465ffbc
commit 10b16947a5
5 changed files with 83 additions and 248 deletions

View File

@@ -1,7 +1,7 @@
import re
import traceback
from typing import Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
@@ -17,34 +17,6 @@ if TYPE_CHECKING:
logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""计算消息的兴趣度
Args:
message: 待处理的消息对象
Returns:
Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词)
"""
if message.is_picid or message.is_emoji:
return 0.0, []
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
# 保留适配器/上游直接标记的提及信号,避免被覆盖
if getattr(message, "is_mentioned", False):
is_mentioned = True
# interested_rate = 0.0
keywords = []
message.interest_value = 1
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
return 1, keywords
class HeartFCMessageReceiver:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
@@ -70,12 +42,16 @@ class HeartFCMessageReceiver:
userinfo = message.message_info.user_info
chat = message.chat_stream
# 2. 兴趣度计算与更新
_, keywords = await _calculate_interest(message)
# 2. 计算at信息
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
print(f"is_mentioned: {is_mentioned}, is_at: {is_at}, reply_probability_boost: {reply_probability_boost}")
message.is_mentioned = is_mentioned
message.is_at = is_at
message.reply_probability_boost = reply_probability_boost
await self.storage.store_message(message, chat)
_heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
# 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"

View File

@@ -529,7 +529,7 @@ class DefaultReplyer:
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中
这是上述中你和{sender}的对话摘要,内容从上面的对话中截取,便于你理解
{core_dialogue_prompt_str}
--------------------------------
"""

View File

@@ -30,26 +30,56 @@ def is_english_letter(char: str) -> bool:
return "a" <= char.lower() <= "z"
def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
except Exception:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
result = f"[{time_str}] {name}: {content}\n"
logger.debug(f"result: {result}")
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
"""解析 platforms 列表,返回平台到账号的映射
Args:
platforms: 格式为 ["platform:account"] 的列表,如 ["tg:123456789", "wx:wxid123"]
Returns:
字典,键为平台名,值为账号
"""
result = {}
for platform_entry in platforms:
if ":" in platform_entry:
platform_name, account = platform_entry.split(":", 1)
result[platform_name.strip()] = account.strip()
return result
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
"""根据当前平台获取对应的账号
Args:
platform: 当前消息的平台
platform_accounts: 从 platforms 列表解析的平台账号映射
qq_account: QQ 账号(兼容旧配置)
Returns:
当前平台对应的账号
"""
if platform == "qq":
return qq_account
elif platform == "telegram":
# 优先使用 tg其次使用 telegram
return platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
else:
# 其他平台直接使用平台名作为键
return platform_accounts.get(platform, "")
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
"""检查消息是否提到了机器人(多平台实现)"""
"""检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or ""
platform = getattr(message.message_info, "platform", "") or ""
qq_id = str(getattr(global_config.bot, "qq_account", "") or "")
tg_id = str(getattr(global_config.bot, "telegram_account", "") or "")
tg_uname = str(getattr(global_config.bot, "telegram_username", "") or "")
# 获取各平台账号
platforms_list = getattr(global_config.bot, "platforms", []) or []
platform_accounts = parse_platform_accounts(platforms_list)
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
# 获取当前平台对应的账号
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
nickname = str(global_config.bot.nickname or "")
alias_names = list(getattr(global_config.bot, "alias_names", []) or [])
@@ -94,32 +124,30 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
is_at = True
is_mentioned = True
# 4) 文本层面的 @ 检测(多平台)
# QQ: @<name:qq_id>
if qq_id and re.search(rf"@<(.+?):{re.escape(qq_id)}>", text):
is_at = True
is_mentioned = True
# Telegram: @username
if platform == "telegram" and tg_uname:
if re.search(rf"@{re.escape(tg_uname)}(\b|$)", text, flags=re.IGNORECASE):
is_at = True
is_mentioned = True
# 4) 统一的 @ 检测逻辑
if current_account and not is_at and not is_mentioned:
if platform == "qq":
# QQ 格式: @<name:qq_id>
if re.search(rf"@<(.+?):{re.escape(current_account)}>", text):
is_at = True
is_mentioned = True
else:
# 其他平台格式: @username 或 @account
if re.search(rf"@{re.escape(current_account)}(\b|$)", text, flags=re.IGNORECASE):
is_at = True
is_mentioned = True
# 5) 回复机器人检测:
# a) 通用显示文本:包含 “(你)” 或 “(你)” 的回复格式
if re.search(r"\[回复 .*?\(你\)", text) or re.search(r"\[回复 .*?(你):", text):
is_mentioned = True
# b) 兼容 ID 形式QQ与Telegram
if qq_id and (
re.search(rf"\[回复 (.+?)\({re.escape(qq_id)}\)(.+?)\],说:", text)
or re.search(rf"\[回复<(.+?)(?=:{re.escape(qq_id)}>)\:{re.escape(qq_id)}>(.+?)\],说:", text)
):
is_mentioned = True
if tg_id and (
re.search(rf"\[回复 (.+?)\({re.escape(tg_id)}\)(.+?)\],说:", text)
or re.search(rf"\[回复<(.+?)(?=:{re.escape(tg_id)}>)\:{re.escape(tg_id)}>(.+?)\],说:", text)
):
is_mentioned = True
# 5) 统一的回复检测逻辑
if not is_mentioned:
# 通用回复格式:包含 "(你)" 或 "(你)"
if re.search(r"\[回复 .*?\(你\)", text) or re.search(r"\[回复 .*?(你):", text):
is_mentioned = True
# ID 形式的回复检测
elif current_account:
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\)(.+?)\],说:", text):
is_mentioned = True
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>(.+?)\],说:", text):
is_mentioned = True
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
if not is_mentioned and keywords:
@@ -157,45 +185,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
return embedding
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
# 获取当前群聊记录内发言的人
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
if not recent_messages:
return []
who_chat_in_group = []
for db_msg in recent_messages:
# user_info = UserInfo.from_dict(
# {
# "platform": msg_db_data["user_platform"],
# "user_id": msg_db_data["user_id"],
# "user_nickname": msg_db_data["user_nickname"],
# "user_cardname": msg_db_data.get("user_cardname", ""),
# }
# )
# if (
# (user_info.platform, user_info.user_id) != sender
# and user_info.user_id != global_config.bot.qq_account
# and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
# and len(who_chat_in_group) < 5
# ): # 排除重复排除消息发送者排除bot限制加载的关系数目
# who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
if (
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
and db_msg.user_info.user_id != global_config.bot.qq_account
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
who_chat_in_group.append(
(db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
)
return who_chat_in_group
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
"""将文本分割成句子,并根据概率合并
@@ -452,42 +441,6 @@ def calculate_typing_time(
return total_time # 加上回车时间
def cosine_similarity(v1, v2):
"""计算余弦相似度"""
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
def text_to_vector(text):
"""将文本转换为词频向量"""
# 分词
words = jieba.lcut(text)
return Counter(words)
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
"""使用简单的余弦相似度计算文本相似度"""
# 将输入文本转换为词频向量
text_vector = text_to_vector(text)
# 计算每个主题的相似度
similarities = []
for topic in topics:
topic_vector = text_to_vector(topic)
# 获取所有唯一词
all_words = set(text_vector.keys()) | set(topic_vector.keys())
# 构建向量
v1 = [text_vector.get(word, 0) for word in all_words]
v2 = [topic_vector.get(word, 0) for word in all_words]
# 计算相似度
similarity = cosine_similarity(v1, v2)
similarities.append((topic, similarity))
# 按相似度降序排序并返回前k个
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
@@ -565,47 +518,6 @@ def get_western_ratio(paragraph):
return western_count / len(alnum_chars)
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
"""计算两个时间点之间的消息数量和文本总长度
Args:
start_time (float): 起始时间戳 (不包含)
end_time (float): 结束时间戳 (包含)
stream_id (str): 聊天流ID
Returns:
tuple[int, int]: (消息数量, 文本总长度)
"""
count = 0
total_length = 0
# 参数校验 (可选但推荐)
if start_time >= end_time:
# logger.debug(f"开始时间 {start_time} 大于或等于结束时间 {end_time},返回 0, 0")
return 0, 0
if not stream_id:
logger.error("stream_id 不能为空")
return 0, 0
# 使用message_repository中的count_messages和find_messages函数
# 构建查询条件
filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
try:
# 先获取消息数量
count = count_messages(filter_query)
# 获取消息内容计算总长度
messages = find_messages(message_filter=filter_query)
total_length = sum(len(msg.processed_plain_text or "") for msg in messages)
return count, total_length
except Exception as e:
logger.error(f"计算消息数量时发生意外错误: {e}")
return 0, 0
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
@@ -740,65 +652,6 @@ def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, Data
return result
# def assign_message_ids_flexible(
# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
# ) -> list:
# """
# 为消息列表中的每个消息分配唯一的简短随机ID增强版
# Args:
# messages: 消息列表
# prefix: ID前缀默认为"msg"
# id_length: ID的总长度不包括前缀默认为6
# use_timestamp: 是否在ID中包含时间戳默认为False
# Returns:
# 包含 {'id': str, 'message': any} 格式的字典列表
# """
# result = []
# used_ids = set()
# for i, message in enumerate(messages):
# # 生成唯一的ID
# while True:
# if use_timestamp:
# # 使用时间戳的后几位 + 随机字符
# timestamp_suffix = str(int(time.time() * 1000))[-3:]
# remaining_length = id_length - 3
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{timestamp_suffix}{random_chars}"
# else:
# # 使用索引 + 随机字符
# index_str = str(i + 1)
# remaining_length = max(1, id_length - len(index_str))
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{index_str}{random_chars}"
# if message_id not in used_ids:
# used_ids.add(message_id)
# break
# result.append({"id": message_id, "message": message})
# return result
# 使用示例:
# messages = ["Hello", "World", "Test message"]
#
# # 基础版本
# result1 = assign_message_ids(messages)
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
#
# # 增强版本 - 自定义前缀和长度
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
#
# # 增强版本 - 使用时间戳
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
def parse_keywords_string(keywords_input) -> list[str]:
# sourcery skip: use-contextlib-suppress
"""

View File

@@ -27,6 +27,9 @@ class BotConfig(ConfigBase):
nickname: str
"""昵称"""
platforms: list[str] = field(default_factory=lambda: [])
"""其他平台列表"""
alias_names: list[str] = field(default_factory=lambda: [])
"""别名列表"""

View File

@@ -1,5 +1,5 @@
[inner]
version = "6.18.3"
version = "6.18.4"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -14,6 +14,9 @@ version = "6.18.3"
[bot]
platform = "qq"
qq_account = "1145141919810" # 麦麦的QQ账号
platforms = ["wx:114514","xx:1919810"] # 麦麦的其他平台账号
nickname = "麦麦" # 麦麦的昵称
alias_names = ["麦叠", "牢麦"] # 麦麦的别名